Skip to content

Commit

Permalink
Some more progress
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak committed Jun 20, 2020
1 parent f801832 commit 8721db2
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 32 deletions.
17 changes: 15 additions & 2 deletions hnn_core/basket.py
Expand Up @@ -66,8 +66,21 @@ def parconnect(self, gid, gid_dict, pos_dict, p):

# this function might make more sense as a method of net?
# par: receive from external inputs
def parreceive(self, gid, gid_dict, pos_dict, p_ext):
# for some gid relating to the input feed:
def parreceive(self, gids, pos_dict, p_ext):
"""Receive from common inputs.
Parameters
----------
gids : dict
A dictionary of the cell types in the network and their
corresponding range of cell IDs
pos_dict : dict
The position dictionary
feeds : list
A list of Rhythmic instances
threshold : float
The spiking threshold
"""
for gid_src, p_src, pos in zip(gid_dict['common'],
p_ext, pos_dict['common']):
# check if AMPA params are defined in the p_src
Expand Down
16 changes: 14 additions & 2 deletions hnn_core/feed.py
Expand Up @@ -142,12 +142,24 @@ def set_prox_connections_layer2(self, g_pyr_ampa, g_pyr_nmda,
Parameters
----------
g_pyr_ampa : float (in uS)
Postsynaptic conductance value of the pyramidal AMPA receptor for the
layer 2 rhythmic input.
Postsynaptic conductance value of the pyramidal AMPA receptor
for the layer 2 rhythmic input.
g_pyr_nmda : float (in uS)
Postsynaptic conductance value of the NMDA receptor for the
external rhythmic input
"""
self.L2_Pyr_ampa['g'] = g_pyr_ampa
self.L2_Pyr_ampa['delay'] = delay
self.L2_Pyr_ampa['lamtha'] = 100.

self.L2_Pyr_nmda['g'] = g_pyr_nmda
self.L2_Pyr_nmda['delay'] = delay
self.L2_Pyr_ampa['lamtha'] = 100.

self.L2_Basket_ampa['g'] = g_basket_ampa
self.L2_Basket_ampa['g'] = delay
self.L2_Basket_ampa['lamtha'] = 100.
pass

def set_prox_connections_layer5(self, g_pyr_ampa, g_pyr_nmda,
g_basket_ampa, g_basket_nmda, delay):
Expand Down
25 changes: 15 additions & 10 deletions hnn_core/network.py
Expand Up @@ -12,7 +12,6 @@
from .feed import ExtFeed
from .pyramidal import L2Pyr, L5Pyr
from .basket import L2Basket, L5Basket
from .params import create_pext


def read_spikes(fname, gid_dict=None):
Expand Down Expand Up @@ -70,6 +69,8 @@ class Network(object):
The parameters
n_jobs : int
The number of jobs to run in parallel
feeds_common : instance of Rhythmic
The rhythmic feeds as input to the network
Attributes
----------
Expand All @@ -88,7 +89,7 @@ class Network(object):
An instance of the Spikes object.
"""

def __init__(self, params, n_jobs=1):
def __init__(self, params, feeds_common=None, n_jobs=1):
from .parallel import create_parallel_context
# setup simulation (ParallelContext)
create_parallel_context(n_jobs=n_jobs)
Expand Down Expand Up @@ -124,9 +125,7 @@ def __init__(self, params, n_jobs=1):
# Global number of external inputs ... automatic counting
# makes more sense
# p_unique represent ext inputs that are going to go to each cell
self.p_common, self.p_unique = create_pext(self.params,
self.params['tstop'])
self.n_common_feeds = len(self.p_common)

# Source list of names
# in particular order (cells, common, names of unique inputs)
self.src_list_new = self._create_src_list()
Expand Down Expand Up @@ -155,7 +154,9 @@ def __init__(self, params, n_jobs=1):
self._gid_assign()
# create cells (and create self.origin in create_cells_pyr())
self.cells = []
self.common_feeds = []

self.feeds_common = feeds_common

# external unique input list dictionary
self.unique_feeds = dict.fromkeys(self.p_unique)
# initialize the lists in the dict
Expand Down Expand Up @@ -269,7 +270,7 @@ def _create_coords_common_feeds(self):
origin_z = np.floor(self.zdiff / 2)
self.origin = (origin_x, origin_y, origin_z)
self.pos_dict['common'] = [self.origin for i in
range(self.n_common_feeds)]
range(len(self.feeds_common))]
# at this time, each of the unique inputs is per cell
for key in self.p_unique.keys():
# create the pos_dict for all the sources
Expand Down Expand Up @@ -329,7 +330,7 @@ def _gid_assign(self):
pc.set_gid2node(gid_input, rank)
self._gid_list.append(gid_input)

for gid_base in range(rank, self.n_common_feeds, nhosts):
for gid_base in range(rank, len(self.feeds_common), nhosts):
# shift the gid_base to the common gid
gid = gid_base + self.gid_dict['common'][0]
# set as usual
Expand Down Expand Up @@ -391,6 +392,9 @@ def _create_all_spike_sources(self):
pc.cell(gid, self.cells[-1].connect_to_target(
None, self.params['threshold']))

for feed in self.feeds:
pc.cell(gid, feed.connect_to_target(self.params['threshold']))

# external inputs are special types of artificial-cells
# 'common': all cells impacted with identical TIMING of spike
# events. NB: cell types can still have different weights for how
Expand Down Expand Up @@ -455,8 +459,9 @@ def _parnet_connect(self):
# parconnect receives connections from other cells
# parreceive receives connections from common external inputs
cell.parconnect(gid, self.gid_dict, self.pos_dict, self.params)
cell.parreceive(gid, self.gid_dict,
self.pos_dict, self.p_common)
cell.parreceive(gids=self.gid_dict['common'],
self.pos_dict, feed=self.feed_common,
threshold=self.params['threshold'])
# now do the unique external feeds specific to these cells
# parreceive_ext receives connections from UNIQUE
# external inputs
Expand Down
47 changes: 29 additions & 18 deletions hnn_core/pyramidal.py
Expand Up @@ -872,24 +872,35 @@ def parconnect(self, gid, gid_dict, pos_dict, p):
'L2_basket', 'L2Basket', lamtha=50.,
postsyns=[self.apicaltuft_gabaa])

# receive from common inputs
# XXX check NetCon connections for proximal inputs with zero weights
def parreceive(self, gid, gid_dict, pos_dict, p_ext):
for gid_src, p_src, pos in zip(gid_dict['common'],
p_ext, pos_dict['common']):
def parreceive(self, gids, pos_dict, feeds, threshold):
"""Receive from common inputs.
Parameters
----------
gids : dict
The cell IDs of the common inputs
pos_dict : dict
The position dictionary
feeds : list
A list of Rhythmic instances
threshold : float
The spiking threshold
"""
for gid_src, feed, pos in zip(gids, feeds, pos_dict['common']):
# Check if AMPA params defined in p_src
if 'L5Pyr_ampa' in p_src.keys():
if hasattr(feed, 'L5Pyr_ampa'):
nc_dict_ampa = {
'pos_src': pos,
'A_weight': p_src['L5Pyr_ampa'][0],
'A_delay': p_src['L5Pyr_ampa'][1],
'lamtha': p_src['lamtha'],
'threshold': p_src['threshold'],
'A_weight': feed.L5Pyr_ampa['g'],
'A_delay': feed.L5Pyr_ampa['delay'],
'lamtha': feed.L5Pyr_ampa['lamtha'],
'threshold': threshold,
'type_src': 'ext'
}

# Proximal feed AMPA synapses
if p_src['loc'] == 'proximal':
if feed.loc == 'proximal':
# basal2_ampa, basal3_ampa, apicaloblique_ampa
self.ncfrom_common.append(
self.parconnect_from_src(gid_src, nc_dict_ampa,
Expand All @@ -901,25 +912,25 @@ def parreceive(self, gid, gid_dict, pos_dict, p_ext):
self.parconnect_from_src(gid_src, nc_dict_ampa,
self.apicaloblique_ampa))
# Distal feed AMPA synsapes
elif p_src['loc'] == 'distal':
elif feed.loc == 'distal':
# apical tuft
self.ncfrom_common.append(
self.parconnect_from_src(gid_src, nc_dict_ampa,
self.apicaltuft_ampa))

# Check if NMDA params defined in p_src
if 'L5Pyr_nmda' in p_src.keys():
if hasattr(feed, 'L5Pyr_nmda'):
nc_dict_nmda = {
'pos_src': pos,
'A_weight': p_src['L5Pyr_nmda'][0],
'A_delay': p_src['L5Pyr_nmda'][1],
'lamtha': p_src['lamtha'],
'threshold': p_src['threshold'],
'A_weight': feed.L5Pyr_nmda['g'],
'A_delay': feed.L5Pyr_nmda['delay'],
'lamtha': feed.L5Pyr_nmda['lamtha'],
'threshold': threshold,
'type_src': 'ext'
}

# Proximal feed NMDA synapses
if p_src['loc'] == 'proximal':
if feed.loc == 'proximal':
# basal2_nmda, basal3_nmda, apicaloblique_nmda
self.ncfrom_common.append(
self.parconnect_from_src(
Expand All @@ -931,7 +942,7 @@ def parreceive(self, gid, gid_dict, pos_dict, p_ext):
self.parconnect_from_src(
gid_src, nc_dict_nmda, self.apicaloblique_nmda))
# Distal feed NMDA synsapes
elif p_src['loc'] == 'distal':
elif feed.loc == 'distal':
# apical tuft
self.ncfrom_common.append(
self.parconnect_from_src(
Expand Down

0 comments on commit 8721db2

Please sign in to comment.