In [1]:
import os
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import gcsfs
import pyarrow.feather as feather
from motif_utils import *
import netsci.metrics.motifs as nsm
import json

## Loading the graph

In [2]:
DATASET = 'fafb_783'
REGION = 'central_complex'
NETWORK = f'{DATASET}_{REGION}'

SYN_LOCAL_PATH = f'data/{NETWORK}/synapses.feather'
GCS_PATH = f'gs://sjcabs_2025_data/fafb/{REGION}/{DATASET}_{REGION}_synapses.feather'
META_PATH = f'gs://sjcabs_2025_data/fafb/{REGION}/{DATASET}_{REGION}_meta.feather'

LOAD_FLYVIS = True
NETWORK = 'e1'

# FLY_VIS_PATH = f'data/flyvis_data/flyvis_full_synapse_df.csv'
FLY_VIS_PATH = f'data/flyvis_data/{NETWORK}/synapses.csv'


if not LOAD_FLYVIS:
    print('Network Name:', NETWORK)

    if os.path.exists(SYN_LOCAL_PATH):
        print(f'Loading from local: {SYN_LOCAL_PATH}')
        synapses_df = pd.read_feather(SYN_LOCAL_PATH)
    else:
        print(f'Loading from GCS: {GCS_PATH}')
        import gcsfs
        gcs = gcsfs.GCSFileSystem(token='google_default')
        import pyarrow.feather as feather
        with gcs.open(GCS_PATH.replace('gs://', ''), 'rb') as f:
            synapses_df = feather.read_feather(f)
        with gcs.open(META_PATH.replace('gs://', ''), 'rb') as f:
            meta_df = feather.read_feather(f)

    # Add edges (multiple synapses between same neurons become edge weight)
    edge_counts = synapses_df.groupby(['pre', 'post']).size().reset_index(name='weight')

else:
    print(f'Loading from FlyVis data: {FLY_VIS_PATH}')
    synapses_df = pd.read_csv(FLY_VIS_PATH)
    print(f'  Columns: {list(synapses_df.columns)}')
    assert not synapses_df.duplicated(subset=['source_index', 'target_index']).any(), "Duplicates found! You actually DO need groupby."
    edge_counts = synapses_df.rename(columns={
        'source_index': 'pre',
        'target_index': 'post',
        'n_syn': 'weight'
    })


print(f'✓ Loaded {len(synapses_df):,} synapses')
synapses_df.head()

Loading from FlyVis data: data/flyvis_data/e1/synapses.csv
  Columns: ['du', 'dv', 'n_syn', 'n_syn_certainty', 'sign', 'source_index', 'source_type', 'source_u', 'source_v', 'target_index', 'target_type', 'target_u', 'target_v']
✓ Loaded 8,174 synapses


Unnamed: 0,du,dv,n_syn,n_syn_certainty,sign,source_index,source_type,source_u,source_v,target_index,target_type,target_u,target_v
0,0,0,40.0,5.859477,-1.0,0,R1,-1,0,56,L1,-1,0
1,0,0,40.0,5.859477,-1.0,1,R1,-1,1,57,L1,-1,1
2,0,0,40.0,5.859477,-1.0,2,R1,0,-1,58,L1,0,-1
3,0,0,40.0,5.859477,-1.0,3,R1,0,0,59,L1,0,0
4,0,0,40.0,5.859477,-1.0,4,R1,0,1,60,L1,0,1


In [7]:
print(f'✓ Loaded {len(synapses_df):,} synapses')
G_connectome = nx.DiGraph()

for _, row in tqdm(edge_counts.iterrows()):
    if (row['pre'] != row['post']):  # Exclude self-connections
        G_connectome.add_edge(row['pre'], row['post'], weight=int(row['weight']))

print(f'✓ Built connectome graph')
print(f'  Nodes: {G_connectome.number_of_nodes():,}')
print(f'  Edges: {G_connectome.number_of_edges():,}')
print(f'  Density: {nx.density(G_connectome):.4f}')

mat_sparse = nx.to_scipy_sparse_array(G_connectome)
bin_mat_sparse = mat_sparse.copy()
bin_mat_sparse[bin_mat_sparse > 1] = 1
# bin_mat_sparse.setdiag(0)

# Create mapping between node IDs and matrix indices
node_list = list(G_connectome.nodes())
node_id_to_index = {node_id: idx for idx, node_id in enumerate(node_list)}
index_to_node_id = {idx: node_id for idx, node_id in enumerate(node_list)}

✓ Loaded 8,174 synapses


8174it [00:00, 46015.29it/s]

✓ Built connectome graph
  Nodes: 437
  Edges: 8,054
  Density: 0.0423





## Basic validation

In [4]:
out_degree = G_connectome.out_degree(index_to_node_id[1])
out_degree_mat = bin_mat_sparse[1, :].sum()
print(out_degree, out_degree_mat)

# for node in tqdm(G_connectome.nodes):
#     assert G_connectome.out_degree(node) == bin_mat_sparse[node_id_to_index[node], :].sum()
#     assert G_connectome.in_degree(node) == bin_mat_sparse[:, node_id_to_index[node]].sum()
#     break

9 9


In [8]:
def stats(conn_mat: np.ndarray):
    print('shape:', conn_mat.shape)
    print(f'total connections: {np.sum(conn_mat)}')
    print('total number of elements in the connectivity matrix (N^2):', conn_mat.size)
    print(f"Non-zero elements: {np.count_nonzero(conn_mat)}")
    print(f"Percentage of non-zero elements: {np.count_nonzero(conn_mat) / conn_mat.size * 100:.2f}%")
    max_post_idx, max_pre_idx = np.unravel_index(conn_mat.argmax(), conn_mat.shape)
    print(f'Max synapses between a single pair of neurons: {conn_mat[max_post_idx, max_pre_idx]} (from {max_pre_idx} to {max_post_idx})')

stats(mat_sparse.toarray())
print()
stats(bin_mat_sparse.toarray())

shape: (437, 437)
total connections: 36666
total number of elements in the connectivity matrix (N^2): 190969
Non-zero elements: 7631
Percentage of non-zero elements: 4.00%
Max synapses between a single pair of neurons: 144 (from 183 to 14)

shape: (437, 437)
total connections: 7631
total number of elements in the connectivity matrix (N^2): 190969
Non-zero elements: 7631
Percentage of non-zero elements: 4.00%
Max synapses between a single pair of neurons: 1 (from 1 to 0)


## Extraction

In [9]:
netsci_motif_keys = [12, 36, 6, 38, 14, 74, 98, 78, 102, 46, 108, 110, 238]

if LOAD_FLYVIS:
    output_dir = f'data/flyvis_data/{NETWORK}/motifs'
    os.makedirs(output_dir, exist_ok=True)
else:
    output_dir = f'data/{NETWORK}/motifs'
os.makedirs(output_dir, exist_ok=True)
fsl_full_path = f'{output_dir}/participation_nodes.h5'
ex_data_path = f'{output_dir}/ex_data.h5'


if os.path.exists(f'{output_dir}/binary_fsl.json'):
    print('Loading pre-calculated motifs...')
    network_fsl = json.load(open(f'{output_dir}/binary_fsl.json'))
    network_fsl = {int(k): v for k, v in network_fsl.items()}
else:
    print('Calculating motifs...')
    n_reals, participating_nodes = nsm.motifs(bin_mat_sparse.toarray(), algorithm='louzoun', participation=True)
    n_reals = n_reals[3:]
    participating_nodes = participating_nodes[3:]

    network_fsl = {netsci_motif_keys[i]: amount for (i, amount) in enumerate(n_reals)}
    fsl_fully_mapped = {netsci_motif_keys[i]: nodes for (i, nodes) in enumerate(participating_nodes)}

    with open(f'{output_dir}/binary_fsl.json', 'w') as f:
        json.dump({k: int(v) for k, v in network_fsl.items()}, f)

    save_motif_participation_nodes_h5(fsl_fully_mapped, fsl_full_path)

Calculating motifs...
Data saved to data/flyvis_data/e1/motifs/participation_nodes.h5


In [None]:
motif_ids = sorted(list(triplets_names.keys()))
motifs = {motif_id: create_base_motif(motif_id) for motif_id in motif_ids}
for motif_id in motif_ids:
    motifs[motif_id].n_real = network_fsl.get(motif_id, 0)

fsl_nodes = {}
neuron_names = node_list

In [11]:
def convert_part_nodes_to_sub_graphs(motif, participating_nodes: list):
    fsl_fully_mapped = []
    for i, sub_graphs in enumerate(tqdm(participating_nodes)):
        nodes = list(sub_graphs)
        graph_nodes = [neuron_names[n] for n in nodes]
        sub_graph_edges = nx.induced_subgraph(G_connectome, graph_nodes).edges
        fsl_fully_mapped.append(tuple(sub_graph_edges))
    motif.sub_graphs = fsl_fully_mapped


def populate_motif(motif: Motif, participating_nodes: list):
    print('converting nodes to sub graphs...')
    convert_part_nodes_to_sub_graphs(motif, participating_nodes)
    print('sorting node roles...')
    motif.node_roles = sort_node_roles_in_sub_graph(appearances=motif.sub_graphs, neuron_names=neuron_names, motif=motif)


def get_motif_roles_freq_csv(motif):
    all_nodes_data = []
    for node in tqdm(neuron_names):
        node_data_dict = {}
        for role in motif.node_roles.keys():
            node_data_dict[f'{motif.id}_{role}'] = motif.node_roles[role].get(node, 0)
        all_nodes_data.append(node_data_dict)

    node_roles_df = pd.DataFrame(all_nodes_data)
    node_roles_df.to_csv(f'{output_dir}/motif_{motif.id}_roles_freq.csv', index=True)


def get_motif_subgraphs_csv(motif):
    all_subgraphs_data = []
    for sub_graph in tqdm(motif.sub_graphs):
        all_subgraphs_data.append(get_sub_graph_mapping_to_motif(sub_graph, motif.role_pattern))

    sub_graphs_df = pd.DataFrame(all_subgraphs_data)
    sub_graphs_df.to_csv(f'{output_dir}/motif_{motif.id}_subgraphs.csv', index=True)

In [14]:
for motif_id in [38]: #[6,78,46,98,78,38]:
    fsl_nodes[motif_id] = load_motif_participation_nodes_h5(fsl_full_path, motif_id)

    assert len(fsl_nodes[motif_id]) == motifs[motif_id].n_real
    print(motifs[motif_id].name)
    print(len(fsl_nodes[motif_id]))

    motif = motifs[motif_id]
    loaded_sg = None
    if os.path.isfile(ex_data_path):
        print('Loading existing motif data...')
        loaded_sg, loaded_roles = load_motif_data(motif.id, ex_data_path)
        if loaded_sg:
            motif.node_roles = loaded_roles
            motif.sub_graphs = loaded_sg
    if not loaded_sg or not os.path.isfile(ex_data_path):
        populate_motif(motif, fsl_nodes[motif_id])
        populate_motif_data(motif.sub_graphs, motif.node_roles, motif.id, ex_data_path)

    get_motif_roles_freq_csv(motif)
    get_motif_subgraphs_csv(motif)

feed forward
9782
converting nodes to sub graphs...


100%|██████████| 9782/9782 [00:00<00:00, 31616.46it/s]

sorting node roles...





Exception: The sub graph is not isomorphic to the motif

In [None]:
motif.role_pattern

In [None]:
motif.sub_graphs[0]


In [None]:
fsl_fully_mapped[38][0]

In [None]:
triplet = [node_id_to_index[node] for node in fsl_fully_mapped[38][0]]
triplet


In [None]:
# Define your nodes of interest
sub = G_connectome.subgraph(triplet)

print(f"Nodes in subgraph: {sub.nodes()}")
print("Edges in subgraph:")
if sub.number_of_edges() == 0:
    print("  No edges found between these specific nodes.")
else:
    for u, v, data in sub.edges(data=True):
        print(f"  {u} -> {v}, weight: {data.get('weight')}")

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

# 1. Create the subgraph from your existing graph
# Replace [node_id_1, node_id_2, node_id_3] with the actual IDs you want to see
subgraph = G_connectome.subgraph(triplet)

# 2. Plotting
plt.figure(figsize=(6, 4))
pos = nx.spring_layout(subgraph, seed=42)  # Layout for consistent positioning

# Draw the nodes and edges
nx.draw_networkx_nodes(subgraph, pos, node_size=700, node_color='skyblue')
nx.draw_networkx_edges(subgraph, pos, width=2, arrowsize=20)
nx.draw_networkx_labels(subgraph, pos, font_size=12, font_family="sans-serif")

# Optional: Draw edge labels (weights) if you want to see them
edge_labels = nx.get_edge_attributes(subgraph, 'weight')
nx.draw_networkx_edge_labels(subgraph, pos)

plt.axis('off')
plt.tight_layout()
plt.show()