In [1]:
%matplotlib ipympl
# Basic libraries for data manipulation
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import pickle
import multiprocessing as mp

# Neuron libraries
from neuron import h
from neuron import load_mechanisms
import neuron

import ipywidgets as wdg
from tqdm import tqdm, tqdm_notebook

#Load main NEURON modules
load_mechanisms('/work/nrn-7.4/x86_64/bin/')
h.load_file('stdrun.hoc')
h('objref nil')

class BallAndStick(object):
    
    def __init__(self, E_PAS = -75.0, Rm = 10000.0, Cm = 1.0, Ra = 150.0, celsius = 23, dend_nseg = 11, 
                 soma_diam = 25, dend_length = 25, dend_diam = 1):
         # Initialise ephys parameters
        self.E_PAS = E_PAS
        self.Rm = Rm
        self.Cm = Cm
        self.Ra = Ra
        self.CELSIUS = celsius
        
        # Create soma and dendrite and connect them together
        self.soma = h.Section(name="soma")
        self.dend = h.Section(name="dend")
        self.dend.connect(self.soma(1))
        
        # Initialise soma and dendrite diameters
        self.dend.L = dend_length # This makes soma 500 microns squared
        self.soma.diam = soma_diam
        self.dend.diam = dend_diam
        self.soma.L = soma_diam
        
        # Set dendritic segments
        self.dend.nseg = dend_nseg
        
        # Insert conductances
        self.soma = self.add_conductances(self.soma)
        self.dend = self.add_conductances(self.dend)
        
        # Define variables to be overwritten
        self.activation_pattern = []
        self.AMPA_ncs = []
        
        # Insert tapering of dendrite
#         self.dend = self.taper_diam(self.dend,2,1)
    
    def taper_diam(self, sec,zero_bound,one_bound):
        for (num_sec, d) in zip(sec, np.linspace(zero_bound, one_bound, sec.nseg)):
            num_sec.diam=d
            
        return sec

    def add_conductances(self, nrn_sec):
        nrn_sec.insert('pas')
        nrn_sec.Ra = self.Ra
        nrn_sec.e_pas = self.E_PAS
        nrn_sec.g_pas = 1.0/self.Rm
        for seg in nrn_sec:
            seg.cm = self.Cm
      
        return nrn_sec
    
    def add_AMPA(self, func = h.Exp2Syn, section = h.Section(), locs = [0.5], gmax = 0.5, tau1 = 0.1, tau2 = 1 ):
        self.AMPA_syns, self.AMPA_ncs = [], []
        for syn_no in range(len(locs)):
            SYN = func(float(locs[syn_no]), sec = section)
            SYN.tau1 = tau1
            SYN.tau2 = tau2
            if type(gmax) == list or type(gmax) == np.ndarray:
                NC = h.NetCon(h.nil, SYN, 0, 0, gmax[syn_no])
            else:
                NC = h.NetCon(h.nil, SYN, 0, 0, gmax)
                
            self.AMPA_syns.append(SYN), self.AMPA_ncs.append(NC)
            
    def add_NMDA(self,locs=[0.5],gmax=[1],rel = [20]):
        self.NMDAlist = []
        self.preNMDA_list = []
        for loc in enumerate(locs):
            PRE = h.Section()
            PRE.diam = 1.0 ; PRE.L=1.0
            PRE.insert('rel')
            PRE.dur_rel = 0.5
            PRE.amp_rel = 2.0
            PRE.del_rel = rel[loc[0]]
            NMDA = h.NMDA_Mg_T(self.dend(loc[1]))
            NMDA.gmax = gmax[loc[0]]
            h.setpointer(PRE(0.5).rel._ref_T,'C',NMDA)
            self.preNMDA_list.append(PRE)
            self.NMDAlist.append(NMDA)
            
    def simulate(self, v_init = -75, t_stop=200, NMDA=False):
        """ Run the main simulation. Accepts AMPA only or with NMDA depending on passing of the NMDA parameters.
        Records from soma."""
        
        self.vec = {}
        for type_rec in "vrec","trec":
            self.vec[type_rec] = h.Vector()

        self.vec["trec"].record(h._ref_t)
        self.vec["vrec"].record(self.soma(0.5)._ref_v)
        
        if NMDA==True:
            self.NMDAgrec, self.NMDAirec = [], []
            for chan in np.arange(0, len(self.NMDAlist)):
                loc = self.NMDAlist[chan].get_loc()
                h.pop_section()
                self.NMDAgrec.append(h.Vector())
                self.NMDAgrec[chan].record(self.NMDAlist[chan]._ref_g)
                self.NMDAirec.append(h.Vector())
                self.NMDAirec[chan].record(self.NMDAlist[chan]._ref_i)
        
        h.celsius = self.CELSIUS
        h.finitialize(v_init)
        neuron.run(t_stop)
        
    def netcon_events(self):
        for syn_event in self.activation_pattern:
            self.AMPA_ncs[syn_event[0]].event(float(syn_event[1]))   
        
    def run_IN_OUT(self, base_AMPA = 0.0005, base_NMDA = 8000, synapses = 16, base_step = 2, 
                   nmda_gradient_top = 1, nmda_gradient_bot = 1, ampa_gradient_top = 1, ampa_gradient_bot = 1,
                   IN_scale_ampa = 1, IN_scale_nmda = 1, OUT_scale_ampa = 1, OUT_scale_nmda = 1,
                   syn_placement_bot = 0, syn_placement_top = 1, dendrite = 13, base_time = 20):
        
        # First reconstruct parameters
        syn_placement = np.linspace(syn_placement_bot,syn_placement_top,synapses)
                    
        for seq_type in ["IN","OUT"]:

            # Define synapse placement for IN and OUT sequences
            if seq_type == "IN":
                syn_sequence = np.linspace(base_time,base_time+synapses*base_step,synapses)[::-1]
            if seq_type == "OUT":
                syn_sequence = np.linspace(base_time,base_time+synapses*base_step,synapses)

            self.activation_pattern = enumerate(syn_sequence)
            
            ampa_gradient = np.linspace(1,1,synapses)
            nmda_gradient = np.linspace(1,1,synapses)
            
            # Scale synapses 
            ampa_gmax = base_AMPA*ampa_gradient*IN_scale_ampa
            nmda_gmax = base_NMDA*nmda_gradient*IN_scale_nmda

            #Add synapses
            self.add_AMPA(locs=syn_placement,gmax=ampa_gmax,tau1=0.1,tau2=1, section =self.dend)
            self.add_NMDA(locs=syn_placement,gmax=nmda_gmax,rel=syn_sequence)

            # Run the simulation and plot
            fih = h.FInitializeHandler(1,self.netcon_events)
            self.simulate(t_stop=500,NMDA=True)

            if seq_type == "IN":
                IN = np.array(self.vec["vrec"])

            if seq_type == "OUT":
                OUT = np.array(self.vec["vrec"])
                
        return IN,OUT
    

## Re-run simulation

In [None]:
import itertools

class Simulator(object):
    
    def __init__(self, simulation_params):
        
        # Initialise parameters
        self.basal_params =  {"E_PAS":-75.0, "Rm":10000.0, "Cm": 1.0, "Ra":150.0, "celsius":23, "dend_nseg":11}
        self.simulation_params = simulation_params
        
        # Identify variables that need to be passed to different functions
        self.model_vars = ["dend_length","dend_diam"]
        self.model_runs = ["base_AMPA","base_NMDA","synapses","base_step"]
        
        self._get_combinations()
        
    def _get_combinations(self):
        
        allNames = sorted(self.simulation_params)
        combinations = itertools.product(*(self.simulation_params[Name] for Name in allNames))
        self.combinations = pd.DataFrame(list(combinations), columns = allNames)
    
    def _run_simulation(self,array):
        # Get iterative variables
        model_sims = array[self.model_vars].to_dict(orient="records")
        model_runs = array[self.model_runs].to_dict(orient="records")
                
        rows = []
        for count in tqdm(range(array.shape[0])):
            model = BallAndStick(**model_sims[count])
            
            IN, OUT = model.run_IN_OUT(**model_runs[count])
            
            temp = array.iloc[count].to_dict()
            temp["IN"] = np.max(IN)
            temp["OUT"] = np.max(OUT)
            rows.append(temp)
            
        return rows
            
    def runner(self, processes = 4):

        arrays = np.array_split(self.combinations, processes)
        with mp.Pool(processes = processes) as pool:
            self.data = pool.map(self._run_simulation,arrays)
            self.data = pd.concat([pd.DataFrame(i) for i in self.data])
            
        
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
    
simulation_params = {"base_AMPA": np.linspace(0.0001,0.0015,20), "base_NMDA": np.linspace(1000,16000,20), 
                     "synapses":list(np.arange(8,26).astype(int)), "dend_length":[25,50,75,100,200], 
                     "dend_diam":[0.5, 1, 2],"base_step":[0.1,1,2,4,8,10]}

sim = Simulator(simulation_params)
sim.runner(processes = 4)
    

 31%|███       | 4/13 [00:06<00:15,  1.72s/it]

In [5]:
sim.data.head()

Unnamed: 0,IN,OUT,base_AMPA,base_NMDA,dend_diam,dend_length,synapses
0,-72.570328,-72.583587,0.0001,1000.0,0.5,25.0,8.0
1,-72.321756,-72.337289,0.0001,1000.0,0.5,25.0,9.0
2,-72.079001,-72.096763,0.0001,1000.0,0.5,25.0,10.0
3,-71.84322,-71.862592,0.0001,1000.0,0.5,25.0,11.0
4,-71.610235,-71.632971,0.0001,1000.0,0.5,25.0,12.0


### Incomplete - Work in progress !!!

In [None]:
class BS_heatmap:
    def __init__(self, file_path, base_step = 8.0, dend_length = 50.0, soma_diam=50.0,dend_diam = 1,split_analysis = False, 
                 min_epsp_threshold=0.2, max_epsp_threshold=0.5):
        self.xdata = []
        self.ydata = []
        self.current_point = [] # Variable to store selected point on a graph
        self.hm_ranges = {}
        self.cbar = []
        
        self.dend_length = dend_length
        self.soma_diam = soma_diam
        
        # Set thresholds
        self.min_epsp_threshold, self.max_epsp_threshold = min_epsp_threshold, max_epsp_threshold
        
        # Initialise data and other figures
        self.df = pd.read_csv(file_path)
        self.fig, (self.ax_heat, self.ax_line) = plt.subplots(nrows=1,ncols=2)
        
        self.heatmap_plot(base_step = base_step,dend_length=dend_length,soma_diam=soma_diam,
                          split_analysis = split_analysis)

        self.cid = self.fig.canvas.mpl_connect("button_press_event",self)

        
    def __call__(self,event):
        # Initialise first point or update if first point has been plotted
        if self.current_point != []:
            self.current_point[0].set_xdata(event.xdata)
            self.current_point[0].set_ydata(event.ydata)
        else:
            self.current_point = self.ax_heat.plot(event.xdata,event.ydata,"mo",ms=5,mec="none")
          
        ampa = ((self.hm_ranges["ampa"][1]-self.hm_ranges["ampa"][0])/50)*event.ydata+self.hm_ranges["ampa"][0]
        nmda = ((self.hm_ranges["nmda"][1]-self.hm_ranges["nmda"][0])/50)*event.xdata+self.hm_ranges["nmda"][0]
        
        self.ax_line.cla()        
        
        # Setup variables
        MODEL.dend.L = self.dend_length
        MODEL.soma.diam = self.soma_diam
        MODEL.soma.L = self.soma_diam
        IN, OUT = MODEL.run_IN_OUT(base_AMPA=ampa, base_NMDA=nmda, base_step=self.base_step)
        print("Soma L:",MODEL.soma.L, "Soma diam:",MODEL.soma.diam,"MODEL.dend.L",MODEL.dend.L,"MODEL.dend.diam:",
             MODEL.dend.diam)
        
        # Show on plot        
        self.ax_line.plot(IN, "r")
        self.ax_line.plot(OUT, "b")
        
    def heatmap_plot(self, base_step = 8.0, dend_length = 50.0, soma_diam = 50.0,split_analysis = False):
        
        self.hm_data = self.df[(self.df["base_step"] == float(base_step)) & (self.df["length"] == float(dend_length)) 
                               & (self.df["soma_diam"] == float(soma_diam))]
        
        #### No dend diam in passed arguments!!!!!!
        
        # Set global variables
        self.base_step, self.length, self.soma_diam = base_step, dend_length, soma_diam
        
        self.cbar = self.ax_heat.imshow(self.hm_data["diff"].values.reshape(50,50))
        self.fig.colorbar(self.cbar)
        
        # Find the current ranges of the plot
        self.hm_ranges["ampa"] = (self.df["ampa"].min(), self.df["ampa"].max())
        self.hm_ranges["nmda"] = (self.df["nmda"].min(), self.df["nmda"].max())
        
        # Add appropriate ranges to the plot
        self.ax_heat.set_xticks([0,50]),self.ax_heat.set_yticks([0,50])
        self.ax_heat.set_xticklabels(self.hm_ranges["nmda"]), self.ax_heat.set_yticklabels(self.hm_ranges["ampa"])
        
        if split_analysis:
            self.parallel_run_accepted()
            masked_data = self.accepted_vector.reshape(50,50)
            masked_data = np.ma.masked_where(masked_data > 0.9, masked_data)
            self.ax_heat.imshow(masked_data, alpha = 0.7)
            
        self.fig.canvas.draw() 
    
    def __run_split_analysis__(self,splitted_df,potentiation_params={"min":2.0, "max":25.0}):
        accepted_df = []
        for row in tqdm(splitted_df.iterrows()):
            
            # Get simulations
            epsp = np.max(MODEL.run_IN_OUT(base_AMPA=row[1]["ampa"], base_NMDA=row[1]["nmda"],synapses=1)) + 75
            
            # Test if epsp mini is within accepted min and max parameters
            test_epsp = not (epsp < self.min_epsp_threshold) | (epsp > self.max_epsp_threshold)
            
            # Add all the parameters to the data frame
            accepted_df.append([row[1]["ampa"],row[1]["nmda"],epsp,test_epsp,
                                MODEL.dend.L, MODEL.soma.L,MODEL.dend.diam, MODEL.soma.diam,
                                row[1]["IN"],row[1]["OUT"]])
            
        df = pd.DataFrame(accepted_df, columns = ["ampa","nmda","epsp","mini_threshold",
                                                    "dend_length","soma_length","dend_diam","soma_diam","IN","OUT"])
        
        # Create accepted vector for all tested conditions
        df["threshold"] = df["mini_threshold"] 
        
        return df
    
    def parallel_run_accepted(self):
        """ Create a masked array to overlay for discovered heatmap to show appropriate EPSP values."""
        
        cores = mp.cpu_count()
        pool = mp.Pool(processes=cores)
        self.accepted_df = pd.concat(pool.map(self.__run_split_analysis__,np.array_split(self.hm_data,cores)))
        self.accepted_vector = self.accepted_df["threshold"]
        
        pool.close()
        
def netcon_events():
    """ Launch AMPA synapses with netcon events. """
    for syn_event in MODEL.activation_pattern:
        MODEL.AMPA_ncs[syn_event[0]].event(float(syn_event[1]))

MODEL = BallAndStick(dend_length=100, soma_diam=50, dend_diam=1)
bs_hm = BS_heatmap("/shared/simulation_results_v2.csv", base_step = 2.0, dend_length = 100, split_analysis=True)
        

  mplDeprecation)


1918it [02:45, 11.56it/s]

## Simple example of multiprocessing

In [16]:
import multiprocessing as mp

def __run_split_analysis__(self,splitted_df):
    accepted_vector = []
    for row in tqdm(splitted_df.iterrows()):
        epsp = np.max(bs.run_IN_OUT(base_AMPA=row[1]["ampa"], base_NMDA=row[1]["nmda"], synapses = 1)) + 75
        if (epsp < min_epsp_threshold) | (epsp > max_epsp_threshold):
            accepted_vector.append(False)
        else:
            accepted_vector.append(True)

    return accepted_vector


min_epsp_threshold, max_epsp_threshold = 0.2, 0.5

cores = mp.cpu_count()
pool = mp.Pool(processes=cores)

accepted_vector = pool.map(run_split_analysis,np.array_split(bs_hm.hm_data,cores))
accepted_vector = [item for sublist in accepted_vector for item in sublist]

pool.close()


NameError: name 'run_split_analysis' is not defined

In [None]:
from skimage.transform import resize
from skimage.io import imread

fig, ax = plt.subplots()
img=imread('/vagrant/parrot.jpg')
imgplot = ax.imshow(resize(img, (50,50)))

masked_data = bs_hm.accepted_vector.reshape(50,50)
masked_data = np.ma.masked_where(masked_data < 0.9, masked_data)
ax.imshow(masked_data, alpha = 0.7)
