diff --git a/doc/api.rst b/doc/api.rst index 075bab4e1..650d19552 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -46,6 +46,8 @@ Visualization (:py:mod:`hnn_core.viz`): plot_cell_morphology plot_psd plot_tfr_morlet + plot_cell_connectivity + plot_connectivity_matrix Parallel backends (:py:mod:`hnn_core.parallel_backends`): --------------------------------------------------------- diff --git a/doc/whats_new.rst b/doc/whats_new.rst index de686b49e..9ad3bf26a 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -24,6 +24,8 @@ Changelog - Add probability argument to :func:`~hnn_core.Network.add_connection`. Connectivity patterns can also be visualized with :func:`~hnn_core.viz.plot_connectivity_matrix`, by `Nick Tolley`_ in `#318 `_ +- Add function to visualize connections originating from individual cells :func:`~hnn_core.viz.plot_cell_connectivity`, by `Nick Tolley`_ in `#339 `_ + Bug ~~~ diff --git a/examples/plot_connectivity.py b/examples/plot_connectivity.py index 5c7824e98..29884999b 100644 --- a/examples/plot_connectivity.py +++ b/examples/plot_connectivity.py @@ -8,7 +8,7 @@ # Author: Nick Tolley -# sphinx_gallery_thumbnail_number = 5 +# sphinx_gallery_thumbnail_number = 2 import os.path as op @@ -38,9 +38,19 @@ # Instantiating the network comes with a predefined set of connections that # reflect the canonical neocortical microcircuit. ``net.connectivity`` # is a list of dictionaries which detail every cell-cell, and drive-cell -# connection. +# connection. The weights of these connections can be visualized with +# :func:`~hnn_core.viz.plot_connectivity_weights` as well as +# :func:`~hnn_core.viz.plot_cell_connectivity` +from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity print(len(net_erp.connectivity)) -print(net_erp.connectivity[0:2]) + +conn_idx = 20 +print(net_erp.connectivity[conn_idx]) +plot_connectivity_matrix(net_erp, conn_idx) + +gid_idx = 11 +src_gid = net_erp.connectivity[conn_idx]['src_range'][gid_idx] +fig, ax = plot_cell_connectivity(net_erp, conn_idx, src_gid) ############################################################################### # Data recorded during simulations are stored under @@ -89,9 +99,7 @@ # activity is visible as vertical lines where several cells fire simultaneously # We can additionally use the ``probability``. argument to create a sparse # connectivity pattern instead of all-to-all. Let's try creating the same -# network with a 10% chance of cells connecting to each other. The resulting -# connectivity pattern can also be visualized with -# ``net.connectivity[idx].plot()`` +# network with a 10% chance of cells connecting to each other. probability = 0.1 net_sparse = default_network(params, add_drives_from_params=True) net_sparse.clear_connectivity() @@ -115,8 +123,16 @@ dpl_sparse = simulate_dipole(net_sparse, n_trials=1) net_sparse.cell_response.plot_spikes_raster() -net_sparse.connectivity[-2].plot() -net_sparse.connectivity[-1].plot() +# Get index of most recently added connection, and a src_gid in src_range. +conn_idx, gid_idx = len(net_sparse.connectivity) - 1, 5 +src_gid = net_sparse.connectivity[conn_idx]['src_range'][gid_idx] +plot_connectivity_matrix(net_sparse, conn_idx) +plot_cell_connectivity(net_sparse, conn_idx, src_gid) + +conn_idx, gid_idx = len(net_sparse.connectivity) - 2, 5 +src_gid = net_sparse.connectivity[conn_idx]['src_range'][gid_idx] +plot_connectivity_matrix(net_sparse, conn_idx) +plot_cell_connectivity(net_sparse, conn_idx, src_gid) ############################################################################### # Using the sparse connectivity pattern produced a lot more spiking in diff --git a/hnn_core/cell.py b/hnn_core/cell.py index d1d6dff2a..f5464e1be 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -25,6 +25,68 @@ def _get_cos_theta(p_secs, sec_name_apical): return cos_thetas +def _calculate_gaussian(x_val, height, lamtha): + """Return height of gaussian at x_val. + + Parameters + ---------- + x_val : float + Value on x-axis to query height of gaussian curve. + height : float + Height of the gaussian curve at zero. + lamtha : float + Space constant. + + Returns + ------- + x_height : float + Height of gaussian at x_val. + + Notes + ----- + Gaussian curve is centered at zero and has a fixed peak height + such the _calculate_gaussian(0, lamtha) returns 1 for all lamtha. + """ + x_height = height * np.exp(-(x_val**2) / (lamtha**2)) + + return x_height + + +def _get_gaussian_connection(src_pos, target_pos, nc_dict): + """Calculate distance dependent connection properties. + + Parameters + ---------- + src_pos : float + Position of source cell. + target_pos : float + Position of target cell. + nc_dict : dict + Dictionary with keys: pos_src, A_weight, A_delay, lamtha + Defines the connection parameters + + Returns + ------- + weight : float + Weight of the synaptic connection. + delay : float + Delay of synaptic connection. + + Notes + ----- + Distance in xy plane is used for gaussian decay. + """ + x_dist = target_pos[0] - src_pos[0] + y_dist = target_pos[1] - src_pos[1] + cell_dist = np.sqrt(x_dist**2 + y_dist**2) + + weight = _calculate_gaussian( + cell_dist, nc_dict['A_weight'], nc_dict['lamtha']) + delay = nc_dict['A_delay'] / _calculate_gaussian( + cell_dist, 1, nc_dict['lamtha']) + return weight, delay + + class _ArtificialCell: """The ArtificialCell class for initializing a NEURON feed source. @@ -477,24 +539,14 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn): from .network_builder import _PC nc = _PC.gid_connect(gid_presyn, postsyn) - # calculate distance between cell positions with pardistance() - d = self._pardistance(nc_dict['pos_src']) - # set props here + + # set props here. nc.threshold = nc_dict['threshold'] - nc.weight[0] = nc_dict['A_weight'] * \ - np.exp(-(d**2) / (nc_dict['lamtha']**2)) - nc.delay = nc_dict['A_delay'] / \ - (np.exp(-(d**2) / (nc_dict['lamtha']**2))) + nc.weight[0], nc.delay = _get_gaussian_connection( + nc_dict['pos_src'], self.pos, nc_dict) return nc - # pardistance function requires pre position, since it is - # calculated on POST cell - def _pardistance(self, pos_pre): - dx = self.pos[0] - pos_pre[0] - dy = self.pos[1] - pos_pre[1] - return np.sqrt(dx**2 + dy**2) - def plot_morphology(self, ax=None, cell_types=None, show=True): """Plot the cell morphology. diff --git a/hnn_core/network.py b/hnn_core/network.py index cb6ed0634..148ed1507 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -17,7 +17,7 @@ from .cells_default import pyramidal, basket from .cell_response import CellResponse from .params import _long_name, _short_name -from .viz import plot_cells, plot_connectivity_matrix +from .viz import plot_cells from .externals.mne import _validate_type, _check_option @@ -1268,25 +1268,6 @@ def __repr__(self): return entr - def plot(self, ax=None, show=True): - """Plot connectivity matrix for instance of _Connectivity object. - - Parameters - ---------- - ax : instance of matplotlib Axes3D | None - An axis object from matplotlib. If None, - a new figure is created. - show : bool - If True, show the figure. - - Returns - ------- - fig : instance of matplotlib Figure - The matplotlib figure handle. - """ - - return plot_connectivity_matrix(self, ax=ax, show=show) - class _NetworkDrive(dict): """A class for containing the parameters of external drives diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 2f234c64a..6c27db717 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -7,6 +7,7 @@ import hnn_core from hnn_core import read_params, default_network from hnn_core.viz import plot_cells, plot_dipole, plot_psd, plot_tfr_morlet +from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity from hnn_core.dipole import simulate_dipole matplotlib.use('agg') @@ -24,6 +25,31 @@ def test_network_visualization(): ax = net.cell_types['L2_pyramidal'].plot_morphology() assert len(ax.lines) == 8 + conn_idx = 0 + plot_connectivity_matrix(net, conn_idx, show=False) + with pytest.raises(TypeError, match='net must be an instance of'): + plot_connectivity_matrix('blah', conn_idx, show_weight=False) + + with pytest.raises(TypeError, match='conn_idx must be an instance of'): + plot_connectivity_matrix(net, 'blah', show_weight=False) + + with pytest.raises(TypeError, match='show_weight must be an instance of'): + plot_connectivity_matrix(net, conn_idx, show_weight='blah') + + src_gid = 5 + plot_cell_connectivity(net, conn_idx, src_gid, show=False) + with pytest.raises(TypeError, match='net must be an instance of'): + plot_cell_connectivity('blah', conn_idx, src_gid=src_gid) + + with pytest.raises(TypeError, match='conn_idx must be an instance of'): + plot_cell_connectivity(net, 'blah', src_gid) + + with pytest.raises(TypeError, match='src_gid must be an instance of'): + plot_cell_connectivity(net, conn_idx, src_gid='blah') + + with pytest.raises(ValueError, match='src_gid not in the'): + plot_cell_connectivity(net, conn_idx, src_gid=-1) + def test_dipole_visualization(): """Test dipole visualisations.""" diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 520590965..fe1e15c07 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -5,6 +5,7 @@ import numpy as np from itertools import cycle +from .externals.mne import _validate_type def _get_plot_data(dpl, layer, tmin, tmax): @@ -569,37 +570,87 @@ def plot_cell_morphology(cell, ax, show=True): return ax -def plot_connectivity_matrix(conn, ax=None, show=True): - """Plot connectivity matrix for instance of _Connectivity object. +def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, + colorbar=True, colormap='Greys', + show=True): + """Plot connectivity matrix with color bar for synaptic weights Parameters ---------- - conn : Instance of _Connectivity object - The _Connectivity object + net : Instance of Network object + The Network object + conn_idx : int + Index of connection to be visualized + from `net.connectivity` + ax : instance of Axes3D + Matplotlib 3D axis + show_weight : bool + If True, visualize connectivity weights as gradient. + If False, all weights set to constant value. + colormap : str + The name of a matplotlib colormap. Default: 'Greys' + colorbar : bool + If True (default), adjust figure to include colorbar. + show : bool + If True, show the plot Returns ------- fig : instance of matplotlib Figure The matplotlib figure handle. - """ import matplotlib.pyplot as plt - from.network import _Connectivity + from matplotlib.ticker import ScalarFormatter + from .network import Network + from .cell import _get_gaussian_connection - if not isinstance(conn, _Connectivity): - raise TypeError('conn must be instance of _Connectivity') + _validate_type(net, Network, 'net', 'Network') + _validate_type(conn_idx, int, 'conn_idx', 'int') + _validate_type(show_weight, bool, 'show_weight', 'bool') if ax is None: _, ax = plt.subplots(1, 1) + # Load objects for distance calculation + conn = net.connectivity[conn_idx] + nc_dict = conn['nc_dict'] + src_type = conn['src_type'] + target_type = conn['target_type'] + src_type_pos = net.pos_dict[src_type] + target_type_pos = net.pos_dict[target_type] + src_range = np.array(conn['src_range']) target_range = np.array(conn['target_range']) connectivity_matrix = np.zeros((len(src_range), len(target_range))) + for src_gid, target_src_pair in conn['gid_pairs'].items(): src_idx = np.where(src_range == src_gid)[0][0] - target_indeces = np.in1d(target_range, target_src_pair) - connectivity_matrix[src_idx, :] = target_indeces + target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] + for target_idx in target_indeces: + src_pos = src_type_pos[src_idx] + target_pos = target_type_pos[target_idx] + + # Identical calculation used in Cell.par_connect_from_src() + if show_weight: + weight, _ = _get_gaussian_connection( + src_pos, target_pos, nc_dict) + else: + weight = 1.0 + + connectivity_matrix[src_idx, target_idx] = weight + + im = ax.imshow(connectivity_matrix, cmap=colormap, interpolation='none') + + ax.set_xlabel('Time (ms)') + ax.set_ylabel('Frequency (Hz)') + + if colorbar: + fig = ax.get_figure() + xfmt = ScalarFormatter() + xfmt.set_powerlimits((-2, 2)) + cbar = fig.colorbar(im, ax=ax, format=xfmt) + cbar.ax.yaxis.set_ticks_position('right') + cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") - ax.imshow(connectivity_matrix, cmap='Greys', interpolation='none') ax.set_xlabel(f"{conn['target_type']} target gids " f"({target_range[0]}-{target_range[-1]})") ax.set_xticklabels(list()) @@ -609,5 +660,115 @@ def plot_connectivity_matrix(conn, ax=None, show=True): ax.set_title(f"{conn['src_type']} -> {conn['target_type']} " f"({conn['loc']}, {conn['receptor']})") + plt.tight_layout() plt_show(show) return ax.get_figure() + + +def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, + colormap='viridis', show=True): + """Plot synaptic weight of connections originating from src_gid. + + Parameters + ---------- + net : Instance of Network object + The Network object + conn_idx : int + Index of connection to be visualized + from `net.connectivity` + src_gid : int + Each cell in a network is uniquely identified by it's "global ID": GID. + ax : instance of Axes3D + Matplotlib 3D axis + colormap : str + The name of a matplotlib colormap. Default: 'viridis' + colorbar : bool + If True (default), adjust figure to include colorbar. + show : bool + If True, show the plot + + Returns + ------- + fig : instance of matplotlib Figure + The matplotlib figure handle. + + Notes + ----- + Target cells will be determined by the connection class given by + net.connectivity[conn_idx]. + If the target cell is not connected to src_gid, it will appear as a + smaller black circle. + src_gid is plotted as a red circle. src_gid will not be plotted if + the connection corresponds to a drive, ex: poisson, bursty, etc. + + """ + import matplotlib.pyplot as plt + from .network import Network + from .cell import _get_gaussian_connection + from matplotlib.ticker import ScalarFormatter + + _validate_type(net, Network, 'net', 'Network') + _validate_type(conn_idx, int, 'conn_idx', 'int') + _validate_type(src_gid, int, 'src_gid', 'int') + if ax is None: + _, ax = plt.subplots(1, 1) + + # Load objects for distance calculation + conn = net.connectivity[conn_idx] + nc_dict = conn['nc_dict'] + src_type = conn['src_type'] + target_type = conn['target_type'] + src_type_pos = net.pos_dict[src_type] + target_type_pos = net.pos_dict[target_type] + + src_range = np.array(conn['src_range']) + if src_gid not in src_range: + raise ValueError(f'src_gid not in the src type range of {src_type} ' + f'gids. Valid gids include {src_range[0]} -> ' + f'{src_range[-1]}') + + target_range = np.array(conn['target_range']) + + # Extract indeces to get position in network + # Index in gid range aligns with net.pos_dict + target_src_pair = conn['gid_pairs'][src_gid] + target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] + + src_idx = np.where(src_range == src_gid)[0][0] + src_pos = src_type_pos[src_idx] + + # Aggregate positions and weight of each connected target + weights, target_x_pos, target_y_pos = list(), list(), list() + for target_idx in target_indeces: + target_pos = target_type_pos[target_idx] + target_x_pos.append(target_pos[0]) + target_y_pos.append(target_pos[1]) + weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) + weights.append(weight) + + im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap) + + # Gather positions of all gids in target_type. + x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))] + y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))] + ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) + + # Only plot src_gid if proper cell type. + if src_type in net.cell_types: + ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) + ax.set_ylabel('Y Position') + ax.set_xlabel('X Position') + ax.set_title(f"{conn['src_type']}-> {conn['target_type']}" + f" ({conn['loc']}, {conn['receptor']})") + + if colorbar: + fig = ax.get_figure() + xfmt = ScalarFormatter() + xfmt.set_powerlimits((-2, 2)) + cbar = fig.colorbar(im, ax=ax, format=xfmt) + cbar.ax.yaxis.set_ticks_position('right') + cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") + + plt.tight_layout() + plt_show(show) + return ax.get_figure(), ax