Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add function to visualize connectivity weights #339

Merged
merged 23 commits into from Jun 5, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 14 additions & 3 deletions examples/plot_connectivity.py
Expand Up @@ -8,7 +8,7 @@

# Author: Nick Tolley <nicholas_tolley@brown.edu>

# sphinx_gallery_thumbnail_number = 5
# sphinx_gallery_thumbnail_number = 1

import os.path as op

Expand Down Expand Up @@ -38,9 +38,19 @@
# Instantiating the network comes with a predefined set of connections that
# reflect the canonical neocortical microcircuit. ``net.connectivity``
# is a list of dictionaries which detail every cell-cell, and drive-cell
# connection.
# connection. The weights of these connections can be visualized with
# :func:`~hnn_core.viz.plot_connectivity_weights` as well as
# :func:`~hnn_core.viz.plot_cell_connectivity`
from hnn_core.viz import plot_connectivity_weights, plot_cell_connectivity
print(len(net_erp.connectivity))
print(net_erp.connectivity[0:2])

conn_idx = 15
print(net_erp.connectivity[conn_idx])
plot_connectivity_weights(net_erp, conn_idx)

conn_idx, gid_idx = 20, 11
src_gid = net_erp.connectivity[conn_idx]['src_range'][gid_idx]
fig, ax = plot_cell_connectivity(net_erp, conn_idx, src_gid)
ntolley marked this conversation as resolved.
Show resolved Hide resolved

###############################################################################
# Data recorded during simulations are stored under
Expand Down Expand Up @@ -118,6 +128,7 @@
net_sparse.connectivity[-2].plot()
net_sparse.connectivity[-1].plot()


###############################################################################
# Using the sparse connectivity pattern produced a lot more spiking in
# the L5 pyramidal cells. Nevertheless there appears to be some rhythmicity
Expand Down
80 changes: 66 additions & 14 deletions hnn_core/cell.py
Expand Up @@ -25,6 +25,68 @@ def _get_cos_theta(p_secs, sec_name_apical):
return cos_thetas


def _calculate_gaussian(x_val, height, lamtha):
ntolley marked this conversation as resolved.
Show resolved Hide resolved
"""Return height of gaussian at x_val.

Parameters
----------
x_val : float
Value on x-axis to query height of gaussian curve.
height : float
Height of the gaussian curve at zero.
lamtha : float
Space constant.

Returns
-------
x_height : float
Height of gaussian at x_val.

Notes
-----
Gaussian curve is centered at zero and has a fixed peak height
such the _calculate_gaussian(0, lamtha) returns 1 for all lamtha.
"""
x_height = height * np.exp(-(x_val**2) / (lamtha**2))

return x_height


def _get_gaussian_connection(src_pos, target_pos, nc_dict):
"""Calculate distance dependent connection properties.

Parameters
----------
src_pos : float
Position of source cell.
target_pos : float
Position of target cell.
nc_dict : dict
Dictionary with keys: pos_src, A_weight, A_delay, lamtha
Defines the connection parameters

Returns
-------
weight : float
Weight of the synaptic connection.
delay : float
Delay of synaptic connection.

Notes
-----
Distance in xy plane is used for gaussian decay.
"""
x_dist = target_pos[0] - src_pos[0]
y_dist = target_pos[1] - src_pos[1]
cell_dist = np.sqrt(x_dist**2 + y_dist**2)
ntolley marked this conversation as resolved.
Show resolved Hide resolved

weight = _calculate_gaussian(
cell_dist, nc_dict['A_weight'], nc_dict['lamtha'])
delay = nc_dict['A_delay'] / _calculate_gaussian(
cell_dist, 1, nc_dict['lamtha'])
return weight, delay


class _ArtificialCell:
"""The ArtificialCell class for initializing a NEURON feed source.

Expand Down Expand Up @@ -477,24 +539,14 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn):
from .network_builder import _PC

nc = _PC.gid_connect(gid_presyn, postsyn)
# calculate distance between cell positions with pardistance()
d = self._pardistance(nc_dict['pos_src'])
# set props here

# set props here.
nc.threshold = nc_dict['threshold']
nc.weight[0] = nc_dict['A_weight'] * \
np.exp(-(d**2) / (nc_dict['lamtha']**2))
nc.delay = nc_dict['A_delay'] / \
(np.exp(-(d**2) / (nc_dict['lamtha']**2)))
nc.weight[0], nc.delay = _get_gaussian_connection(
nc_dict['pos_src'], self.pos, nc_dict)

return nc

# pardistance function requires pre position, since it is
# calculated on POST cell
def _pardistance(self, pos_pre):
dx = self.pos[0] - pos_pre[0]
dy = self.pos[1] - pos_pre[1]
return np.sqrt(dx**2 + dy**2)

def plot_morphology(self, ax=None, cell_types=None, show=True):
"""Plot the cell morphology.

Expand Down
148 changes: 146 additions & 2 deletions hnn_core/viz.py
Expand Up @@ -5,6 +5,7 @@

import numpy as np
from itertools import cycle
from .externals.mne import _validate_type


def _get_plot_data(dpl, layer, tmin, tmax):
Expand Down Expand Up @@ -569,13 +570,84 @@ def plot_cell_morphology(cell, ax, show=True):
return ax


def plot_connectivity_weights(net, conn_idx, ax=None, show=True):
Copy link
Collaborator

@jasmainak jasmainak May 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we still have the plot_connectivity_matrix function? Shouldn't it be absorbed into this one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was realizing it is pretty superfluous while writing this. I'll add a flag that skips the weight computation and just plots black/white squares.

Then this can be renamed plot_connectivity_matrix.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might want to ignore this comment now since I understand the situation better now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function is still relevant after getting the cell-specific plotting function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I am still leaning towards my previous comment in this thread. While the weight visualization here is obviously not as intuitive as cell-specific plotting, it seems weird not to include the functionality.

I'm planning to remove the "black and white" matrix function, and rename this function to plot_connectivity_matrix. Then I can set a show_weights=False flag for the default behavior.

"""Plot connectivity matrix with color bar for synaptic weights

Parameters
----------
net : Instance of Network object
The Network object
conn_idx : int
Index of connection to be visualized
from `net.connectivity`
ax : instance of Axes3D
Matplotlib 3D axis
show : bool
If True, show the plot

Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure handle.
"""
import matplotlib.pyplot as plt
from .network import Network
from .cell import _get_gaussian_connection

_validate_type(net, Network, 'net', 'Network')
_validate_type(conn_idx, int, 'conn_idx', 'int')
if ax is None:
_, ax = plt.subplots(1, 1)

# Load objects for distance calculation
conn = net.connectivity[conn_idx]
nc_dict = conn['nc_dict']
src_type = conn['src_type']
target_type = conn['target_type']
src_type_pos = net.pos_dict[src_type]
target_type_pos = net.pos_dict[target_type]

src_range = np.array(conn['src_range'])
target_range = np.array(conn['target_range'])
connectivity_matrix = np.zeros((len(src_range), len(target_range)))

for src_gid, target_src_pair in conn['gid_pairs'].items():
src_idx = np.where(src_range == src_gid)[0][0]
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0]
for target_idx in target_indeces:
src_pos = src_type_pos[src_idx]
target_pos = target_type_pos[target_idx]

# Identical calculation used in Cell.par_connect_from_src()
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict)

connectivity_matrix[src_idx, target_idx] = weight

ax.imshow(connectivity_matrix, cmap='Greys', interpolation='none')
ax.set_xlabel(f"{conn['target_type']} target gids "
f"({target_range[0]}-{target_range[-1]})")
ax.set_xticklabels(list())
ax.set_ylabel(f"{conn['src_type']} source gids "
f"({src_range[0]}-{src_range[-1]})")
ax.set_yticklabels(list())
ax.set_title(f"{conn['src_type']} -> {conn['target_type']} "
f"({conn['loc']}, {conn['receptor']})")

plt_show(show)
return ax.get_figure()


def plot_connectivity_matrix(conn, ax=None, show=True):
"""Plot connectivity matrix for instance of _Connectivity object.

Parameters
----------
conn : Instance of _Connectivity object
The _Connectivity object
ax : instance of Axes3D
Matplotlib 3D axis
show : bool
If True, show the plot

Returns
-------
Expand All @@ -586,8 +658,7 @@ def plot_connectivity_matrix(conn, ax=None, show=True):
import matplotlib.pyplot as plt
from.network import _Connectivity

if not isinstance(conn, _Connectivity):
raise TypeError('conn must be instance of _Connectivity')
_validate_type(conn, _Connectivity, 'conn', '_Connectivity')
if ax is None:
_, ax = plt.subplots(1, 1)

Expand All @@ -611,3 +682,76 @@ def plot_connectivity_matrix(conn, ax=None, show=True):

plt_show(show)
return ax.get_figure()


def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, show=True):
"""Plot synaptic weight of connections from a src_gid.

Parameters
----------
net : Instance of Network object
The Network object
conn_idx : int
Index of connection to be visualized
from `net.connectivity`
src_gid : int
Each cell in a network is uniquely identified by it's "global ID": GID.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm in doubt what this function does. Could you be more explicit about the first two arguments? I don't understand which int's to enter. Loving the output though!!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely will add a better docstring. conn_idx is awkward to work with at the moment as it is really just the index of the connection you'd like to visualize in net.connectivity[conn_idx]. Unfortunately the alternative of passing the _Connectivity object directly isn't much better.

I think this awkwardness will be resolved when I implement a utility function discussed with @jasmainak. It will look like

net.pick_connectivity(src_type='L5_pyramidal', receptor='ampa',...)

and return the indeces of all connections that match the provided arguments.

ntolley marked this conversation as resolved.
Show resolved Hide resolved
ax : instance of Axes3D
Matplotlib 3D axis
show : bool
If True, show the plot

Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure handle.
im : Instance of matplotlib AxesImage
The matplotlib AxesImage handle.
"""
import matplotlib.pyplot as plt
from .network import Network
from .cell import _get_gaussian_connection
Comment on lines +706 to +707
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't nest this if possible ... not sure if you have a circular import issue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked, unfortunately the nesting is required here as well.


_validate_type(net, Network, 'net', 'Network')
_validate_type(conn_idx, int, 'conn_idx', 'int')
_validate_type(src_gid, int, 'src_gid', 'int')
if ax is None:
_, ax = plt.subplots(1, 1)

# Load objects for distance calculation
conn = net.connectivity[conn_idx]
nc_dict = conn['nc_dict']
src_type = conn['src_type']
target_type = conn['target_type']
src_type_pos = net.pos_dict[src_type]
target_type_pos = net.pos_dict[target_type]

src_range = np.array(conn['src_range'])
target_range = np.array(conn['target_range'])

# Extract indeces to get position in network
# Index in gid range aligns with net.pos_dict
target_src_pair = conn['gid_pairs'][src_gid]
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0]

src_idx = np.where(src_range == src_gid)[0][0]
src_pos = src_type_pos[src_idx]

weights, x_pos, y_pos = list(), list(), list()
for target_idx in target_indeces:
target_pos = target_type_pos[target_idx]
x_pos.append(target_pos[0])
y_pos.append(target_pos[1])
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict)
weights.append(weight)

ax.scatter(x_pos, y_pos, c=weights)

ax.scatter(src_pos[0], src_pos[1], color='red', s=100)
ax.set_ylabel('Y Position')
ax.set_xlabel('X Position')
ax.set_title(f"{conn['src_type']} (gid={src_gid}) -> {conn['target_type']}"
f" ({conn['loc']}, {conn['receptor']})")

plt_show(show)
return ax.get_figure(), ax