Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add function to visualize connectivity weights #339

Merged
merged 23 commits into from Jun 5, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 18 additions & 8 deletions examples/plot_connectivity.py
Expand Up @@ -8,7 +8,7 @@

# Author: Nick Tolley <nicholas_tolley@brown.edu>

# sphinx_gallery_thumbnail_number = 5
# sphinx_gallery_thumbnail_number = 1

import os.path as op

Expand Down Expand Up @@ -38,9 +38,19 @@
# Instantiating the network comes with a predefined set of connections that
# reflect the canonical neocortical microcircuit. ``net.connectivity``
# is a list of dictionaries which detail every cell-cell, and drive-cell
# connection.
# connection. The weights of these connections can be visualized with
# :func:`~hnn_core.viz.plot_connectivity_weights` as well as
# :func:`~hnn_core.viz.plot_cell_connectivity`
from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity
print(len(net_erp.connectivity))
print(net_erp.connectivity[0:2])

conn_idx = 20
print(net_erp.connectivity[conn_idx])
plot_connectivity_matrix(net_erp, conn_idx)

gid_idx = 11
src_gid = net_erp.connectivity[conn_idx]['src_range'][gid_idx]
fig, ax = plot_cell_connectivity(net_erp, conn_idx, src_gid)
ntolley marked this conversation as resolved.
Show resolved Hide resolved

###############################################################################
# Data recorded during simulations are stored under
Expand Down Expand Up @@ -89,9 +99,7 @@
# activity is visible as vertical lines where several cells fire simultaneously
# We can additionally use the ``probability``. argument to create a sparse
# connectivity pattern instead of all-to-all. Let's try creating the same
# network with a 10% chance of cells connecting to each other. The resulting
# connectivity pattern can also be visualized with
# ``net.connectivity[idx].plot()``
# network with a 10% chance of cells connecting to each other.
probability = 0.1
net_sparse = default_network(params, add_drives_from_params=True)
net_sparse.clear_connectivity()
Expand All @@ -115,8 +123,10 @@
dpl_sparse = simulate_dipole(net_sparse, n_trials=1)
net_sparse.cell_response.plot_spikes_raster()

net_sparse.connectivity[-2].plot()
net_sparse.connectivity[-1].plot()
# Get index of most recently added connection
conn_idx = len(net_sparse.connectivity)
plot_connectivity_matrix(net_sparse, conn_idx, show_weight=False)
plot_connectivity_matrix(net_sparse, conn_idx - 1, show_weight=False)

###############################################################################
# Using the sparse connectivity pattern produced a lot more spiking in
Expand Down
80 changes: 66 additions & 14 deletions hnn_core/cell.py
Expand Up @@ -25,6 +25,68 @@ def _get_cos_theta(p_secs, sec_name_apical):
return cos_thetas


def _calculate_gaussian(x_val, height, lamtha):
ntolley marked this conversation as resolved.
Show resolved Hide resolved
"""Return height of gaussian at x_val.

Parameters
----------
x_val : float
Value on x-axis to query height of gaussian curve.
height : float
Height of the gaussian curve at zero.
lamtha : float
Space constant.

Returns
-------
x_height : float
Height of gaussian at x_val.

Notes
-----
Gaussian curve is centered at zero and has a fixed peak height
such the _calculate_gaussian(0, lamtha) returns 1 for all lamtha.
"""
x_height = height * np.exp(-(x_val**2) / (lamtha**2))

return x_height


def _get_gaussian_connection(src_pos, target_pos, nc_dict):
"""Calculate distance dependent connection properties.

Parameters
----------
src_pos : float
Position of source cell.
target_pos : float
Position of target cell.
nc_dict : dict
Dictionary with keys: pos_src, A_weight, A_delay, lamtha
Defines the connection parameters

Returns
-------
weight : float
Weight of the synaptic connection.
delay : float
Delay of synaptic connection.

Notes
-----
Distance in xy plane is used for gaussian decay.
"""
x_dist = target_pos[0] - src_pos[0]
y_dist = target_pos[1] - src_pos[1]
cell_dist = np.sqrt(x_dist**2 + y_dist**2)
ntolley marked this conversation as resolved.
Show resolved Hide resolved

weight = _calculate_gaussian(
cell_dist, nc_dict['A_weight'], nc_dict['lamtha'])
delay = nc_dict['A_delay'] / _calculate_gaussian(
cell_dist, 1, nc_dict['lamtha'])
return weight, delay


class _ArtificialCell:
"""The ArtificialCell class for initializing a NEURON feed source.

Expand Down Expand Up @@ -477,24 +539,14 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn):
from .network_builder import _PC

nc = _PC.gid_connect(gid_presyn, postsyn)
# calculate distance between cell positions with pardistance()
d = self._pardistance(nc_dict['pos_src'])
# set props here

# set props here.
nc.threshold = nc_dict['threshold']
nc.weight[0] = nc_dict['A_weight'] * \
np.exp(-(d**2) / (nc_dict['lamtha']**2))
nc.delay = nc_dict['A_delay'] / \
(np.exp(-(d**2) / (nc_dict['lamtha']**2)))
nc.weight[0], nc.delay = _get_gaussian_connection(
nc_dict['pos_src'], self.pos, nc_dict)

return nc

# pardistance function requires pre position, since it is
# calculated on POST cell
def _pardistance(self, pos_pre):
dx = self.pos[0] - pos_pre[0]
dy = self.pos[1] - pos_pre[1]
return np.sqrt(dx**2 + dy**2)

def plot_morphology(self, ax=None, cell_types=None, show=True):
"""Plot the cell morphology.

Expand Down
21 changes: 1 addition & 20 deletions hnn_core/network.py
Expand Up @@ -17,7 +17,7 @@
from .cells_default import pyramidal, basket
from .cell_response import CellResponse
from .params import _long_name, _short_name
from .viz import plot_cells, plot_connectivity_matrix
from .viz import plot_cells
from .externals.mne import _validate_type, _check_option


Expand Down Expand Up @@ -1268,25 +1268,6 @@ def __repr__(self):

return entr

def plot(self, ax=None, show=True):
"""Plot connectivity matrix for instance of _Connectivity object.

Parameters
----------
ax : instance of matplotlib Axes3D | None
An axis object from matplotlib. If None,
a new figure is created.
show : bool
If True, show the figure.

Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure handle.
"""

return plot_connectivity_matrix(self, ax=ax, show=show)


class _NetworkDrive(dict):
"""A class for containing the parameters of external drives
Expand Down
21 changes: 21 additions & 0 deletions hnn_core/tests/test_viz.py
Expand Up @@ -7,6 +7,7 @@
import hnn_core
from hnn_core import read_params, default_network
from hnn_core.viz import plot_cells, plot_dipole, plot_psd, plot_tfr_morlet
from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity
from hnn_core.dipole import simulate_dipole

matplotlib.use('agg')
Expand All @@ -24,6 +25,26 @@ def test_network_visualization():
ax = net.cell_types['L2_pyramidal'].plot_morphology()
assert len(ax.lines) == 8

conn_idx = 0
with pytest.raises(TypeError, match='net must be an instance of'):
plot_connectivity_matrix('blah', conn_idx, show_weight=False)

with pytest.raises(TypeError, match='conn_idx must be an instance of'):
plot_connectivity_matrix(net, 'blah', show_weight=False)

with pytest.raises(TypeError, match='show_weight must be an instance of'):
plot_connectivity_matrix(net, conn_idx, show_weight='blah')

src_gid = 5
with pytest.raises(TypeError, match='net must be an instance of'):
plot_cell_connectivity('blah', conn_idx, src_gid=src_gid)

with pytest.raises(TypeError, match='conn_idx must be an instance of'):
plot_cell_connectivity(net, 'blah', src_gid=src_gid)

with pytest.raises(TypeError, match='src_gid must be an instance of'):
plot_cell_connectivity(net, conn_idx, src_gid='blah')
ntolley marked this conversation as resolved.
Show resolved Hide resolved


def test_dipole_visualization():
"""Test dipole visualisations."""
Expand Down