# Extract and display a connectivity matrix from a SONATA circuit

Copyright (c) 2025 Open Brain Institute

Authors: Christoph Pokorny

Last modified: 08.2025

## Summary
This analysis extracts and visualizes a matrix of connection probabilities or #synapses per connection (mean/std/...), grouped by a selected neuron property (layer, m-type, ...).
For details, see the [README](README.md).

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import time

from bluepysnap import Circuit
from connectome_manipulator.connectome_comparison import connectivity
from datetime import datetime
from entitysdk import Client, ProjectContext, models
from ipywidgets import widgets, interact
from obi_auth import get_token
from obi_notebook import get_projects
from obi_notebook import get_entities

## Circuit selection and download

A SONATA circuit will be analyzed. To obtain its data, we first select
the circuit and then download it to the local file system.

Should the circuit of interest already be available on the local file system, you can skip ahead to the section `Circuit analysis` below.

#### Project selection
As a first step we select one of the projects we have access to that the circuit is associated with. If the circuit of interest is part of the public OBI assets, any project can be selected.

In [None]:
# Obtain an authentication token for the production environment
token = get_token(environment="production", auth_mode="daf")
# Project context for the selected project; it is passed on to the client and
# to the entity selection in the cells below
# NOTE(review): presumably rendered as an interactive selection widget -- confirm in obi_notebook
project_context = get_projects.get_projects(token)

#### Circuit selection

Next, we select the circuit. If you already know the unique identifier of the circuit of interest, paste it below into line 4 of the next cell.

Otherwise, a widget for circuit selection will be created that allows you to simply mark the circuit of interest.

In [None]:
client = Client(environment="production", project_context=project_context, token_manager=token)

# Optional: Download using unique ID
entity_ID = "<CIRCUIT-ID>"  # <<< FILL IN UNIQUE CIRCUIT ID HERE

if entity_ID == "<CIRCUIT-ID>":
    # Placeholder left untouched: select the circuit from a table of entities
    circuit_ids = get_entities.get_entities(
        "circuit",
        token,
        [],
        project_context=project_context,
        multi_select=False,
        exclude_scales=["single"],
        show_pages=True,
        page_size=12,
        default_scale="small",
        add_columns=["subject.name"],
    )
else:
    # A unique ID was filled in: use it directly
    circuit_ids = [entity_ID]

#### Fetch circuit
The circuit is copied to the local system at the expected location.

In [None]:
# Fetch circuit
# Running this cell before a circuit has been marked in the selection table
# leaves the ID list empty; fail with a clear message instead of a bare IndexError
if not circuit_ids:
    raise IndexError("ERROR: No circuit selected! Please select a circuit or fill in its unique ID above.")
fetched = client.get_entity(entity_id=circuit_ids[0], entity_type=models.Circuit)
print(f"Circuit fetched: {fetched.name} (ID {fetched.id})\n")
print(f"#Neurons: {fetched.number_neurons}, #Synapses: {fetched.number_synapses}, #Connections: {fetched.number_connections}\n")
print(f"{fetched.description}\n")

# Download SONATA circuit files
# Use the first asset labelled "sonata_circuit"; report explicitly if there is none
asset = next((_asset for _asset in fetched.assets if _asset.label == "sonata_circuit"), None)
if asset is None:
    raise IndexError(f"ERROR: Circuit '{fetched.name}' (ID {fetched.id}) has no 'sonata_circuit' asset!")
asset_dir = asset.path
# Time-stamped output folder, so that repeated downloads do not overwrite each other
circuit_dir = "analysis_circuit_" + datetime.now().strftime('%Y-%m-%d_%H-%M-%S_%f')

t0 = time.perf_counter()  # Monotonic clock, not affected by system clock adjustments
client.download_directory(
    entity_id=fetched.id,
    entity_type=models.Circuit,
    asset_id=asset.id,
    output_path=circuit_dir,
    max_concurrent=4,  # Parallel file download
)
t = time.perf_counter() - t0
print(f"Circuit files downloaded to '{os.path.join(circuit_dir, asset_dir)}' in {t:.1f}s")

## Circuit analysis

By default, the circuit from the download location is used. If a circuit from some other location is used, please modify the `circuit_config = ...` path below accordingly.

In [None]:
# Location of the circuit config (defaults to the folder the circuit was downloaded to)
circuit_config = os.path.join(circuit_dir, asset_dir, "circuit_config.json")

# Stop right here if the config file is missing (only the file name is reported)
assert os.path.exists(circuit_config), f"ERROR: Circuit config '{os.path.basename(circuit_config)}' not found!"

Loading SONATA circuit. Selections of the edge population containing the synapses, as well as pre-/post-synaptic node sets defining groups of neurons are possible.

In [None]:
c = Circuit(circuit_config)
e_populations = c.edges.population_names
assert len(e_populations) > 0, "ERROR: No edge population found!"

# Default selection: first edge population with a biophysical source, else the very first one
src_popul_types = [c.edges[_pop].source.type for _pop in e_populations]
default_idx = src_popul_types.index("biophysical") if "biophysical" in src_popul_types else 0

# Node sets are optional filters; "None" means no restriction to a node set
node_sets = list(c.node_sets.content.keys())
e_popul_wdgt = widgets.Dropdown(
    options=e_populations,
    description="Edge population:",
    value=e_populations[default_idx],
    style={"description_width": "auto"},
    layout=widgets.Layout(width="max-content"),
)
pre_nset_wdgt = widgets.Dropdown(
    options=[None] + node_sets,
    description="Pre-synaptic node set:",
    style={"description_width": "auto"},
    layout=widgets.Layout(width="max-content"),
)
post_nset_wdgt = widgets.Dropdown(
    options=[None] + node_sets,
    description="Post-synaptic node set:",
    style={"description_width": "auto"},
    layout=widgets.Layout(width="max-content"),
)
display(e_popul_wdgt)
display(pre_nset_wdgt)
display(post_nset_wdgt)

## Grouping selection

The connectivity is grouped by one selected (categorical) neuron property, like layer, m-type, etc.

In [None]:
def get_props_categorical(nodes):
    """Return the sorted, unique names of node properties suitable for grouping.

    These are all properties stored with a categorical dtype, plus "region",
    "layer", "mtype", "etype", and "synapse_class" whenever the population has
    them (even if they are not categorical).

    Args:
        nodes: Node population providing ``get()`` and ``property_names``.

    Returns:
        Array of property names, as returned by ``np.unique``.
    """
    # An empty node selection is enough to inspect the column dtypes
    dtypes = nodes.get([]).dtypes
    props = [_name for _name, _dtype in dtypes.items() if isinstance(_dtype, pd.CategoricalDtype)]

    # Add certain properties, even if not categorical
    always_offered = ["region", "layer", "mtype", "etype", "synapse_class"]
    props += [_p for _p in always_offered if _p in nodes.property_names]

    return np.unique(props)

def get_props_float(nodes):
    """Return the names of all node properties whose dtype compares equal to ``float``.

    Args:
        nodes: Node population providing ``property_dtypes`` (dtypes indexed
            by property name).

    Returns:
        Array of property names.
    """
    dtypes = nodes.property_dtypes
    is_float = dtypes == float
    return dtypes[is_float].index.values

# Only properties available in both the source and the target node population are offered
_edge_popul = c.edges[e_popul_wdgt.value]
node_props = np.intersect1d(
    get_props_categorical(_edge_popul.source),
    get_props_categorical(_edge_popul.target),
)
float_props = np.intersect1d(
    get_props_float(_edge_popul.source),
    get_props_float(_edge_popul.target),
)
# Pre-select x/y/z for the distance computation, as far as they exist
selected_props = list(np.intersect1d(["x", "y", "z"], float_props))

groupby_wdgt = widgets.Dropdown(
    options=node_props,
    value="layer" if "layer" in node_props else node_props[0],
    description="Group-by:",
    style={"description_width": "auto"},
    layout=widgets.Layout(width="max-content"),
)
distance_wdgt = widgets.IntSlider(
    value=100,
    min=10,
    max=1000,
    step=10,
    description="Max distance",
    readout=True,
    style={"description_width": "auto"},
)
dist_props_wdgt = widgets.SelectMultiple(
    rows=10,
    options=float_props,
    value=selected_props,
    description="Distance properties",
    style={"description_width": "auto"},
)
use_dist_wdgt = widgets.Checkbox(value=False, description="Use distance cutoff")

def display_fcn(use_dist_val):
    """Show the distance-related widgets if `use_dist_val` is true, hide them otherwise.

    Sets the layout visibility of the notebook-level ``distance_wdgt`` and
    ``dist_props_wdgt``.
    """
    visibility = "visible" if use_dist_val else "hidden"
    distance_wdgt.layout.visibility = visibility
    dist_props_wdgt.layout.visibility = visibility
# Couple the checkbox to display_fcn, which shows/hides the distance widgets
i = widgets.interactive(display_fcn, use_dist_val=use_dist_wdgt)
# Display order: group-by dropdown, checkbox (through the interactive container),
# then the two distance-related widgets
display(groupby_wdgt)
display(i)
display(distance_wdgt)
display(dist_props_wdgt)


## Connectivity extraction

Extract the connectivity for all pairs of groups using functionality from [connectome-manipulator](https://github.com/openbraininstitute/connectome-manipulator).

In [None]:
# The distance cutoff (and the properties it is computed from) is only passed on if enabled
if use_dist_wdgt.value:
    max_dist_val, props_val = distance_wdgt.value, list(dist_props_wdgt.value)
else:
    max_dist_val, props_val = None, None

conn_dict = connectivity.compute(
    c,
    sel_src={"node_set": pre_nset_wdgt.value},
    sel_dest={"node_set": post_nset_wdgt.value},
    edges_popul_name=e_popul_wdgt.value,
    group_by=groupby_wdgt.value,
    max_distance=max_dist_val,
    props_for_distance=props_val,
)

## Interactive visualization

Interactive visualization of the connectivity matrix. The user can select what characteristic to display, such as connection probabilities or mean/max/min/SEM/std of #synapses per connection (SEM ... standard error of the mean). Also, empty groups (i.e., groups w/o any neurons or connections) can be excluded using a checkbox.

In [None]:
# Interactive plot function
def plot_fct(res_sel, empty_sel):
    """Plot one characteristic of the extracted connectivity as a matrix.

    Reads the notebook-level ``conn_dict`` and ``groupby_wdgt``.

    Args:
        res_sel: Key of the entry in ``conn_dict`` to display.
        empty_sel: If true, all groups are shown. Otherwise, source/target
            groups without any connection probability above zero are
            removed before plotting.
    """
    def drop_empty_groups(full_dict, key):
        """Return a reduced dict with only `key` and "common", restricted to non-empty groups."""
        has_conn = full_dict["conn_prob"]["data"] > 0.0
        keep_src = np.any(has_conn, axis=1)
        keep_tgt = np.any(has_conn, axis=0)

        # Shallow copies with replaced entries, so the full results stay untouched
        rows_kept = full_dict[key]["data"][keep_src, :]
        reduced = {**full_dict[key], "data": rows_kept[:, keep_tgt]}
        common = full_dict["common"]
        reduced_common = {
            **common,
            "src_group_values": [_grp for _grp, _keep in zip(common["src_group_values"], keep_src) if _keep],
            "tgt_group_values": [_grp for _grp, _keep in zip(common["tgt_group_values"], keep_tgt) if _keep],
        }
        return {key: reduced, "common": reduced_common}

    plot_dict = conn_dict if empty_sel else drop_empty_groups(conn_dict, res_sel)
    data = plot_dict[res_sel]["data"]

    if data.size == 0:
        # Nothing left after filtering: show a placeholder figure instead
        plt.figure()
        plt.text(0, 0, "Nothing to show...")
        plt.axis("off")
        plt.axis("tight")
        plt.show()
        return

    vmax = np.max(data)
    if vmax == 0.0:
        vmax = 1.0  # Fall back to a non-degenerate color range
    connectivity.plot(plot_dict[res_sel], plot_dict["common"], vmin=0.0, vmax=vmax, group_by=groupby_wdgt.value)

In [None]:
# Offer every result entry that has a unit, sorted by key; the unit string is the
# label shown in the dropdown, the key is the selected value
res_sel_opt = [(conn_dict[_k]["unit"], _k) for _k in sorted(conn_dict) if "unit" in conn_dict[_k]]
res_sel_wdgt = widgets.Dropdown(
    options=res_sel_opt,
    description="Display:",
    style={"description_width": "auto"},
    layout=widgets.Layout(width="max-content"),
)
empty_sel_wdgt = widgets.Checkbox(
    value=True,
    description="Show empty groups",
    style={"description_width": "auto"},
    layout=widgets.Layout(width="max-content"),
)
iplot = interact(plot_fct, res_sel=res_sel_wdgt, empty_sel=empty_sel_wdgt)