In [3]:
import pandas as pd
import numpy as np

In [4]:
connections = pd.read_csv(
    "../../new_data/connections.csv",
    dtype={"pre_root_id": "string", "post_root_id": "string", "syn_count": np.int32},
)

nc = pd.read_table(
    "../../new_data/neuron_annotations.tsv",
    dtype={
        "root_id": "string",
        "soma_x": np.float32,
        "soma_y": np.float32,
        "soma_z": np.float32,
        "cell_type": "string",
    },
)[["root_id", "pos_x", "pos_y", "pos_z", "soma_x", "soma_y", "soma_z", "cell_type"]]
# fill missing soma nans with pos
nc["soma_x"] = nc["soma_x"].fillna(nc["pos_x"])
nc["soma_y"] = nc["soma_y"].fillna(nc["pos_y"])
nc["soma_z"] = nc["soma_z"].fillna(nc["pos_z"])
nc = nc.drop(columns=["pos_x", "pos_y", "pos_z"])

  nc = pd.read_table(


In [3]:
def compute_total_synapse_length(connections, nc):

    df = connections.merge(
        nc,
        left_on="pre_root_id",
        right_on="root_id",
        suffixes=("", "_pre"),
    ).merge(
        nc,
        left_on="post_root_id",
        right_on="root_id",
        suffixes=("_pre", "_post"),
    )

    # Drop unneeded columns to free memory
    df = df.drop(columns=["root_id_pre", "root_id_post", "pre_root_id", "post_root_id"])

    # Vectorized distance calculation
    distances = np.sqrt(
        (df["soma_x_pre"] - df["soma_x_post"]) ** 2
        + (df["soma_y_pre"] - df["soma_y_post"]) ** 2
        + (df["soma_z_pre"] - df["soma_z_post"]) ** 2
    )

    # Multiply by synapse counts and sum
    total_length = np.sum(distances * df["syn_count"])

    return total_length

In [4]:
total_length = compute_total_synapse_length(connections, nc)

# Shuffling strategy 1: shuffle randomly, then reduce synapse count until happy

In [5]:
import numpy as np
import pandas as pd


def shuffle_post_root_id(connections):
    """
    Returns a copy of `connections` where the 'post_root_id' column
    is randomly shuffled, leaving 'pre_root_id' and 'syn_count' intact.
    """
    new_conns = connections.copy()
    # Shuffle post_root_id in place
    shuffled_posts = np.random.permutation(new_conns["post_root_id"].values)
    new_conns["post_root_id"] = shuffled_posts
    return new_conns


def compute_total_synapse_length(connections, nc):
    """
    Computes sum(distance * syn_count) by merging with `nc`.
    connections: DataFrame with columns [pre_root_id, post_root_id, syn_count]
    nc: DataFrame with columns [root_id, soma_x, soma_y, soma_z]
    Returns a float (total wiring length).
    """
    # Merge pre
    df = connections.merge(
        nc,
        left_on="pre_root_id",
        right_on="root_id",
        suffixes=("", "_pre"),
    ).merge(
        nc,
        left_on="post_root_id",
        right_on="root_id",
        suffixes=("_pre", "_post"),
    )

    # Drop unneeded columns
    df = df.drop(columns=["root_id_pre", "root_id_post", "pre_root_id", "post_root_id"])

    # Compute Euclidean distance
    dx = df["soma_x_pre"] - df["soma_x_post"]
    dy = df["soma_y_pre"] - df["soma_y_post"]
    dz = df["soma_z_pre"] - df["soma_z_post"]
    dist = np.sqrt(dx * dx + dy * dy + dz * dz)

    # Weighted sum
    total_length = np.sum(dist * df["syn_count"])
    return total_length


def match_wiring_length_with_syn_scale(
    connections,
    nc,
    real_length,
    scale_low=0.0,
    scale_high=2.0,
    max_iter=5,
    tolerance=0.01,
):
    """
    Iteratively search for a global scale factor s in [scale_low, scale_high]
    such that sum(distance * (syn_count*s)) ~ real_length.

    connections: DataFrame [pre_root_id, post_root_id, syn_count]
    nc         : DataFrame [root_id, soma_x, soma_y, soma_z]
    real_length: Target total wiring length
    scale_low  : initial lower bound on scale
    scale_high : initial upper bound on scale
    max_iter   : max number of binary search iterations
    tolerance  : fraction (or absolute) to decide "close enough"

    Returns the scaled DataFrame. The 'syn_count' column is overwritten
    with the final scaled values.
    """
    # Copy so we don't modify the original DataFrame
    conns_scaled = connections.copy()

    for i in range(max_iter):
        mid = 0.5 * (scale_low + scale_high)

        # Temporarily scale syn_count by mid
        conns_scaled["syn_count"] = connections["syn_count"] * mid
        length_est = compute_total_synapse_length(conns_scaled, nc)

        # Check how close we are
        ratio = length_est / real_length
        print(f"Iter={i}, scale={mid:.4f}, length={length_est:.2f}, ratio={ratio:.3f}")

        # If we're still below the real length, we need a bigger scale
        if length_est < real_length:
            scale_low = mid
        else:
            scale_high = mid

        # If ratio is within ~1 +/- tolerance, we can stop
        if abs(ratio - 1.0) < tolerance:
            break

    # Final scale
    final_scale = 0.5 * (scale_low + scale_high)
    conns_scaled["syn_count"] = connections["syn_count"] * final_scale

    return conns_scaled

In [6]:

np.random.seed(12345)
connections_shuffled = shuffle_post_root_id(connections)

# 3) Use a small binary search to find a global syn_count scale factor
#    that brings the total (dist * syn_count) close to the real_length
scaled_random = match_wiring_length_with_syn_scale(
    connections_shuffled,
    nc,
    total_length,
    scale_low=0.0,
    scale_high=2.0,
    max_iter=10,
    tolerance=0.01,
)

# 4) scaled_random now has syn_count scaled so that total wiring length
#    is in the same ballpark as the real connectome
final_length = compute_total_synapse_length(scaled_random, nc)
print("Final scaled random total length =", final_length)

Iter=0, scale=1.0000, length=3231021020915.70, ratio=2.443
Iter=1, scale=0.5000, length=1615510510457.85, ratio=1.222
Iter=2, scale=0.2500, length=807755255228.93, ratio=0.611
Iter=3, scale=0.3750, length=1211632882843.38, ratio=0.916
Iter=4, scale=0.4375, length=1413571696650.62, ratio=1.069
Iter=5, scale=0.4062, length=1312602289747.00, ratio=0.993
Final scaled random total length = 1363086993198.8096


In [42]:
scaled_random["syn_count"] = np.round(scaled_random["syn_count"]).astype(np.int32)
scaled_random.to_csv("../../new_data/connections_random.csv", index=False)

# Shuffling strategy 2: bin by distance, and shuffle within bins

In [None]:
import numpy as np
import pandas as pd
import time

from tqdm import tqdm

import os
from paths import PROJECT_ROOT


def shuffle_within_bin(group):
    """Shuffle entire rows within a bin"""
    if len(group) <= 1:
        return group

    # Create a shuffled copy of the group
    return group.sample(frac=1.0).reset_index(drop=True)


def create_length_preserving_random_network(
    connections, neurons, bins=10, tolerance=0.1
):
    print("Starting fully vectorized network randomization using pandas...")
    start_time = time.time()

    # Ensure numeric types
    connections = connections.copy()
    neurons = neurons.copy()

    connections["pre_root_id"] = connections["pre_root_id"].astype(int)
    connections["post_root_id"] = connections["post_root_id"].astype(int)
    connections["syn_count"] = connections["syn_count"].astype(int)
    neurons["root_id"] = neurons["root_id"].astype(int)

    # Identify retinal and decision neurons
    retinal_ids = set(
        neurons[neurons["cell_type"].isin(["R1-6", "R7", "R8"])]["root_id"]
    )
    decision_ids = set(
        neurons[neurons["cell_type"].isin(["KCapbp-m", "KCapbp-ap2", "KCapbp-ap1"])][
            "root_id"
        ]
    )

    print(
        f"Identified {len(retinal_ids)} retinal neurons and {len(decision_ids)} decision neurons"
    )

    # Create preserve mask
    preserve_mask = connections["pre_root_id"].isin(retinal_ids) | connections[
        "post_root_id"
    ].isin(decision_ids)

    preserved_connections = connections[preserve_mask].copy()
    randomizable_connections = connections[~preserve_mask].copy()

    print(
        f"Preserved {len(preserved_connections)} connections, will randomize {len(randomizable_connections)}"
    )

    # Join with neuron coordinates - single vectorized operation
    # Add source neuron coordinates
    pre_neurons = neurons[["root_id", "soma_x", "soma_y", "soma_z"]].copy()
    pre_neurons.columns = ["pre_root_id", "pre_x", "pre_y", "pre_z"]

    # Add target neuron coordinates
    post_neurons = neurons[["root_id", "soma_x", "soma_y", "soma_z"]].copy()
    post_neurons.columns = ["post_root_id", "post_x", "post_y", "post_z"]

    # Join in one step
    randomizable_with_coords = randomizable_connections.merge(
        pre_neurons, on="pre_root_id"
    ).merge(post_neurons, on="post_root_id")

    # Calculate distances vectorized
    print("Calculating distances (vectorized)...")
    randomizable_with_coords["distance"] = np.sqrt(
        (randomizable_with_coords["pre_x"] - randomizable_with_coords["post_x"]) ** 2
        + (randomizable_with_coords["pre_y"] - randomizable_with_coords["post_y"]) ** 2
        + (randomizable_with_coords["pre_z"] - randomizable_with_coords["post_z"]) ** 2
    )

    # Calculate original wiring length
    original_total_length = (
        randomizable_with_coords["distance"] * randomizable_with_coords["syn_count"]
    ).sum()
    print(
        f"Original total wiring length for randomizable connections: {original_total_length}"
    )

    # Create distance bins
    print(f"Creating {bins} distance bins...")
    randomizable_with_coords["bin"] = pd.qcut(
        randomizable_with_coords["distance"], bins, labels=False
    )

    # Shuffle within bins
    print("Shuffling connections within distance bins (vectorized)...")
    shuffled_pre_ids = randomizable_with_coords["pre_root_id"].copy()

    # Replace the shuffling section with this:
    print("Shuffling entire connections within distance bins...")
    shuffled_randomizable = pd.DataFrame()

    # Group by bin and shuffle entire connections
    shuffled_randomizable = pd.DataFrame()
    for bin_id in tqdm(range(bins)):
        bin_group = randomizable_with_coords[randomizable_with_coords["bin"] == bin_id]
        shuffled_bin = shuffle_within_bin(bin_group)
        shuffled_randomizable = pd.concat(
            [
                shuffled_randomizable,
                shuffled_bin[["pre_root_id", "post_root_id", "syn_count"]],
            ]
        )

    # Calculate distance for shuffled connections
    shuffled_with_coords = shuffled_randomizable.merge(
        pre_neurons, on="pre_root_id"
    ).merge(post_neurons, on="post_root_id")

    shuffled_with_coords["distance"] = np.sqrt(
        (shuffled_with_coords["pre_x"] - shuffled_with_coords["post_x"]) ** 2
        + (shuffled_with_coords["pre_y"] - shuffled_with_coords["post_y"]) ** 2
        + (shuffled_with_coords["pre_z"] - shuffled_with_coords["post_z"]) ** 2
    )

    # Calculate shuffled wiring length
    shuffled_total_length = (
        shuffled_with_coords["distance"] * shuffled_with_coords["syn_count"]
    ).sum()

    # Calculate preserved connections wiring length
    preserved_with_coords = preserved_connections.merge(
        pre_neurons, on="pre_root_id"
    ).merge(post_neurons, on="post_root_id")

    preserved_with_coords["distance"] = np.sqrt(
        (preserved_with_coords["pre_x"] - preserved_with_coords["post_x"]) ** 2
        + (preserved_with_coords["pre_y"] - preserved_with_coords["post_y"]) ** 2
        + (preserved_with_coords["pre_z"] - preserved_with_coords["post_z"]) ** 2
    )

    preserved_total_length = (
        preserved_with_coords["distance"] * preserved_with_coords["syn_count"]
    ).sum()

    # Combine preserved and shuffled
    final_connections = pd.concat(
        [
            preserved_connections[["pre_root_id", "post_root_id", "syn_count"]],
            shuffled_randomizable,
        ],
        ignore_index=True,
    )

    # Calculate total length
    total_length = shuffled_total_length + preserved_total_length
    print(f"Shuffled randomizable wiring length: {shuffled_total_length}")
    print(f"Preserved wiring length: {preserved_total_length}")
    print(f"Total wiring length: {total_length}")

    # Calculate ratio
    ratio = shuffled_total_length / original_total_length
    print(f"Ratio of randomized wiring lengths: {ratio:.4f}")
    print(f"Target range: {1-tolerance:.4f} to {1+tolerance:.4f}")

    if not (1 - tolerance <= ratio <= 1 + tolerance):
        print(
            f"Warning: Ratio {ratio:.4f} is outside the target range. You may want to try again with more bins."
        )

    print(f"Total time: {time.time() - start_time:.2f} seconds")

    return final_connections

In [None]:
random_connections = create_length_preserving_random_network(
    connections, nc, bins=20, tolerance=0.1
)

Starting fully vectorized network randomization using pandas...
Identified 11119 retinal neurons and 916 decision neurons
Preserved 140514 connections, will randomize 16707483
Calculating distances (vectorized)...
Original total wiring length for randomizable connections: 1396282370978.2185
Creating 20 distance bins...
Shuffling connections within distance bins (vectorized)...
Shuffling entire connections within distance bins...


100%|██████████| 20/20 [00:04<00:00,  4.44it/s]


Shuffled randomizable wiring length: 1396282370978.2188
Preserved wiring length: 4911623927.495159
Total wiring length: 1401193994905.7139
Ratio of randomized wiring lengths: 1.0000
Target range: 0.9000 to 1.1000
Total time: 18.60 seconds


In [None]:
random_connections.to_csv(os.path.join(PROJECT_ROOT, "new_data", "connections_random3.csv"), index=False)