Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] ENH: Add function to search for specific connections #367

Merged
merged 22 commits into from Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -38,6 +38,8 @@ Changelog

- Each drive spike train sampled through an independent process corresponds to a single artificial drive cell, the number of which users can set when adding drives with `n_drive_cells` and `cell_specific`, by `Ryan Thorpe`_ in `#383 <https://github.com/jonescompneurolab/hnn-core/pull/383>`_

- Add :func:`~hnn_core.Network.pick_connection` to query the indices of specific connections in :attr:`~hnn_core.Network.connectivity`, by `Nick Tolley`_ in `#367 <https://github.com/jonescompneurolab/hnn-core/pull/367>`_

Bug
~~~

Expand Down
23 changes: 14 additions & 9 deletions examples/howto/plot_connectivity.py
Expand Up @@ -40,17 +40,23 @@
# is a list of dictionaries which detail every cell-cell, and drive-cell
# connection. The weights of these connections can be visualized with
# :func:`~hnn_core.viz.plot_connectivity_weights` as well as
# :func:`~hnn_core.viz.plot_cell_connectivity`
# :func:`~hnn_core.viz.plot_cell_connectivity`. We can search for specific
# connections using ``pick_connection`` which returns the indices
# of ``net.connectivity`` that match the provided parameters.
from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity
from hnn_core.network import pick_connection

print(len(net_erp.connectivity))

conn_idx = 6
conn_indices = pick_connection(
net=net_erp, src_gids='L5_basket', target_gids='L5_pyramidal',
loc='soma', receptor='gabaa')
conn_idx = conn_indices[0]
print(net_erp.connectivity[conn_idx])
plot_connectivity_matrix(net_erp, conn_idx)

gid_idx = 11
src_gid = net_erp.connectivity[conn_idx]['src_range'][gid_idx]
src_gid = net_erp.connectivity[conn_idx]['src_gids'][gid_idx]
fig = plot_cell_connectivity(net_erp, conn_idx, src_gid)

###############################################################################
Expand All @@ -70,8 +76,6 @@
# and L2 basket cells. :meth:`hnn_core.Network.add_connection` allows
# connections to be specified with either cell names, or the cell IDs (gids)
# directly.


def get_network(probability=1.0):
net = jones_2009_model(params, add_drives_from_params=True)
net.clear_connectivity()
Expand Down Expand Up @@ -115,16 +119,17 @@ def get_network(probability=1.0):

###############################################################################
# We can plot the sparse connectivity pattern between cell populations.
conn_idx = len(net_sparse.connectivity) - 1
plot_connectivity_matrix(net_sparse, conn_idx)
conn_indices = pick_connection(
net=net_sparse, src_gids='L2_basket', target_gids='L2_basket',
loc='soma', receptor='gabaa')

conn_idx = len(net_sparse.connectivity) - 2
conn_idx = conn_indices[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much better! now only thing is that you shouldn't have to do this:

Suggested change
conn_idx = conn_indices[0]

so plot_connectivity_matrix accepts a list. Can be a separate PR though. Are there other functions which benefit from pick_connection ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that'd be a bit nicer, but what do you think the alternate behavior should be with a list? Returning multiple plots or subplots?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

umm ... either is fine. Maybe subplots is better so user is not overwhelmed by plot windows. Each row is one connectivity matrix if multiple subplots. And you can even implement a scrollbar if there are too many subplots

plot_connectivity_matrix(net_sparse, conn_idx)

###############################################################################
# Note that the sparsity is in addition to the weight decay with distance
# from the source cell.
src_gid = net_sparse.connectivity[conn_idx]['src_range'][5]
src_gid = net_sparse.connectivity[conn_idx]['src_gids'][5]
plot_cell_connectivity(net_sparse, conn_idx, src_gid=src_gid)

###############################################################################
Expand Down
52 changes: 52 additions & 0 deletions hnn_core/check.py
@@ -0,0 +1,52 @@
"""Input check functions."""

# Authors: Nick Tolley <nicholas_tolley@brown.edu>

from .params import _long_name
from .externals.mne import _validate_type, _check_option


def _check_gids(gids, gid_ranges, valid_cells, arg_name):
"""Format different gid specifications into list of gids"""
_validate_type(gids, (int, list, range, str, None), arg_name,
'int list, range, str, or None')

# Convert gids to list
if gids is None:
return list()
if isinstance(gids, int):
gids = [gids]
elif isinstance(gids, str):
_check_option(arg_name, gids, valid_cells)
gids = gid_ranges[_long_name(gids)]

cell_type = _gid_to_type(gids[0], gid_ranges)
for gid in gids:
_validate_type(gid, int, arg_name)
gid_type = _gid_to_type(gid, gid_ranges)
if gid_type is None:
raise AssertionError(
f'{arg_name} {gid} not in net.gid_ranges')
if gid_type != cell_type:
raise AssertionError(f'All {arg_name} must be of the same type')

return gids


def _gid_to_type(gid, gid_ranges):
"""Reverse lookup of gid to type."""
for gidtype, gids in gid_ranges.items():
if gid in gids:
return gidtype


def _string_input_to_list(input_str, valid_str, arg_name):
"""Convert input strings to list"""
if input_str is None:
input_str = list()
elif isinstance(input_str, str):
input_str = [input_str]
for str_item in input_str:
_check_option(arg_name, str_item, valid_str)

return input_str
151 changes: 126 additions & 25 deletions hnn_core/network.py
Expand Up @@ -20,6 +20,7 @@
from .viz import plot_cells
from .externals.mne import _validate_type, _check_option
from .extracellular import ExtracellularArray
from .check import _check_gids, _gid_to_type, _string_input_to_list


def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff=1307.4):
Expand Down Expand Up @@ -151,6 +152,119 @@ def _connection_probability(conn, probability, seed=0):
conn['gid_pairs'].pop(src_gid)


def pick_connection(net, src_gids=None, target_gids=None,
loc=None, receptor=None):
"""Returns indices of connections that match search parameters.

Parameters
----------
net : Instance of Network object
The Network object
src_gids : str | int | range | list of int | None
Identifier for source cells. Passing str arguments
('L2_pyramidal', 'L2_basket', 'L5_pyramidal', 'L5_basket') is
equivalent to passing a list of gids for the relvant cell type.
source - target connections are made in an all-to-all pattern.
target_gids : str | int | range | list of int | None
Identifer for targets of source cells. Passing str arguments
('L2_pyramidal', 'L2_basket', 'L5_pyramidal', 'L5_basket') is
equivalent to passing a list of gids for the relvant cell type.
source - target connections are made in an all-to-all pattern.
loc : str | list of str | None
Location of synapse on target cell. Must be
'proximal', 'distal', or 'soma'. Note that inhibitory synapses
(receptor='gabaa' or 'gabab') of L2 pyramidal neurons are only
valid loc='soma'.
receptor : str | list of str | None
Synaptic receptor of connection. Must be one of:
'ampa', 'nmda', 'gabaa', or 'gabab'.

Returns
-------
conn_indices : list of int
List of indices corresponding to items in net.connectivity.
Connection indices are included if any of the provided parameter
values are present in a connection.

Notes
-----
Passing a list of values to a single parameter corresponds to a
logical OR operation across indices. For example,
loc=['distal', 'proximal'] returns all connections that target
distal or proximal dendrites.

Passing multiple parameters corresponds to a logical AND operation.
For example, net.pick_connection(loc='distal', receptor='ampa')
returns only the indices of connections that target the distal
dendrites and have ampa receptors.
"""

# Convert src and target gids to lists
valid_srcs = list(net.gid_ranges.keys()) # includes drives as srcs
valid_targets = list(net.cell_types.keys())
src_gids = _check_gids(src_gids, net.gid_ranges,
valid_srcs, 'src_gids')
target_gids = _check_gids(target_gids, net.gid_ranges,
valid_targets, 'target_gids')

_validate_type(loc, (str, list, None), 'loc', 'str, list, or None')
_validate_type(receptor, (str, list, None), 'receptor',
'str, list, or None')

valid_loc = ['proximal', 'distal', 'soma']
valid_receptor = ['ampa', 'nmda', 'gabaa', 'gabab']

# Convert receptor and loc to list
loc = _string_input_to_list(loc, valid_loc, 'loc')
receptor = _string_input_to_list(receptor, valid_receptor, 'receptor')

# Create lookup dictionaries
src_dict, target_dict = dict(), dict()
loc_dict, receptor_dict = dict(), dict()
for conn_idx, conn in enumerate(net.connectivity):
# Store connections matching each src_gid
for src_gid in conn['src_gids']:
if src_gid in src_dict:
src_dict[src_gid].append(conn_idx)
else:
src_dict[src_gid] = [conn_idx]
# Store connections matching each target_gid
for target_gid in conn['target_gids']:
if target_gid in target_dict:
target_dict[target_gid].append(conn_idx)
else:
target_dict[target_gid] = [conn_idx]
# Store connections matching each location
if conn['loc'] in loc_dict:
loc_dict[conn['loc']].append(conn_idx)
else:
loc_dict[conn['loc']] = [conn_idx]
# Store connections matching each receptor
if conn['receptor'] in receptor_dict:
receptor_dict[conn['receptor']].append(conn_idx)
else:
receptor_dict[conn['receptor']] = [conn_idx]

# Look up conn indeces that match search terms and add to set.
conn_set = set()
search_pairs = [(src_gids, src_dict), (target_gids, target_dict),
(loc, loc_dict), (receptor, receptor_dict)]
for search_terms, search_dict in search_pairs:
inner_set = set()
# Union of indices which match inputs for single parameter
for term in search_terms:
inner_set = inner_set.union(search_dict[term])
# Intersection across parameters
if conn_set and inner_set:
conn_set = conn_set.intersection(inner_set)
else:
conn_set = conn_set.union(inner_set)

conn_set = list(conn_set)
conn_set.sort()
return conn_set


class Network(object):
"""The Network class.

Expand Down Expand Up @@ -925,9 +1039,7 @@ def _add_cell_type(self, cell_name, pos, cell_template=None):

def gid_to_type(self, gid):
"""Reverse lookup of gid to type."""
for gidtype, gids in self.gid_ranges.items():
if gid in gids:
return gidtype
return _gid_to_type(gid, self.gid_ranges)

def _gid_to_cell(self, gid):
"""Reverse lookup of gid to cell.
Expand Down Expand Up @@ -992,17 +1104,14 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
"""
conn = _Connectivity()
threshold = self.threshold
_validate_type(src_gids, (int, list, range, str), 'src_gids',
'int list, range, or str')

_validate_type(target_gids, (int, list, range, str), 'target_gids',
'int list, range or str')
valid_cells = list(self.cell_types.keys())

# Convert src_gids to list
if isinstance(src_gids, int):
src_gids = [src_gids]
elif isinstance(src_gids, str):
_check_option('src_gids', src_gids, valid_cells)
src_gids = self.gid_ranges[_long_name(src_gids)]
src_gids = _check_gids(src_gids, self.gid_ranges,
valid_cells, 'src_gids')

# Convert target_gids to list of list, one element for each src_gid
if isinstance(target_gids, int):
Expand Down Expand Up @@ -1037,30 +1146,22 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
raise AssertionError(
'All target_gids must be of the same type')
conn['target_type'] = target_type
conn['target_range'] = self.gid_ranges[_long_name(target_type)]
conn['target_gids'] = list(target_set)
conn['num_targets'] = len(target_set)

if len(target_gids) != len(src_gids):
raise AssertionError('target_gids must have a list for each src.')

# Format gid_pairs and add to conn dictionary
gid_pairs = dict()
src_type = self.gid_to_type(src_gids[0])
for src_gid, target_src_pair in zip(src_gids, target_gids):
_validate_type(src_gid, int, 'src_gid', 'int')
gid_type = self.gid_to_type(src_gid)
if gid_type is None:
raise AssertionError(
f'src_gid {src_gid} not in net.gid_ranges')
elif gid_type != src_type:
raise AssertionError('All src_gids must be of the same type')
if not allow_autapses:
mask = np.in1d(target_src_pair, src_gid, invert=True)
target_src_pair = np.array(target_src_pair)[mask].tolist()
gid_pairs[src_gid] = target_src_pair

conn['src_type'] = src_type
conn['src_range'] = self.gid_ranges[_long_name(src_type)]
conn['src_type'] = self.gid_to_type(src_gids[0])
conn['src_gids'] = list(set(src_gids))
conn['num_srcs'] = len(src_gids)

conn['gid_pairs'] = gid_pairs
Expand Down Expand Up @@ -1189,10 +1290,10 @@ class _Connectivity(dict):
Number of unique source gids.
num_targets : int
Number of unique target gids.
src_range : range
Range of gids identified by src_type.
target_range : range
Range of gids identified by target_type.
src_gids : list of int
List of unique source gids in connection.
target_gidst : list of int
List of unique target gids in connection.
loc : str
Location of synapse on target cell. Must be
'proximal', 'distal', or 'soma'. Note that inhibitory synapses
Expand Down
2 changes: 1 addition & 1 deletion hnn_core/network_builder.py
Expand Up @@ -442,7 +442,7 @@ def _connect_celltypes(self):
for conn in connectivity:
loc, receptor = conn['loc'], conn['receptor']
nc_dict = deepcopy(conn['nc_dict'])
# Gather indeces of targets on current node
# Gather indices of targets on current node
valid_targets = set()
for src_gid, target_gids in conn['gid_pairs'].items():
filtered_targets = list()
Expand Down