# Plot histograms

In [None]:
import numpy as np
import brainlit
import scipy
from brainlit.utils import swc
from cloudvolume.exceptions import SkeletonDecodeError
from brainlit.algorithms.trace_analysis.fit_spline import GeometricGraph
from brainlit.algorithms.trace_analysis.spline_fxns import curvature, torsion
import os
from pathlib import Path
import pandas as pd
from networkx.readwrite import json_graph
import json
import matplotlib.pyplot as plt
import torch

## Define brain, find trace data folder

In [None]:
# Experiment configuration: which brain to analyze and how distance from the
# root (cell body) is measured for the histograms below.
#
# brain: "brain1" or "brain2".
brain = "brain1"
# length_def: 0 - length from root, 1 - relative length from root,
#             2 - relative length along segment.
length_def = 1

# Number of bins used for every curvature/torsion histogram.
bins = 2048

VALID_BRAINS = ("brain1", "brain2")
LENGTH_DEFINITIONS = {
    0: "length from root",
    1: "relative length from root",
    2: "relative length along segment",
}
# Fail fast on a typo here instead of with a confusing NameError several cells later.
if brain not in VALID_BRAINS:
    raise ValueError(f"brain must be one of {VALID_BRAINS}, got {brain!r}")
if length_def not in LENGTH_DEFINITIONS:
    raise ValueError(
        f"length_def must be one of {sorted(LENGTH_DEFINITIONS)}, got {length_def!r}"
    )

# root_dir is the grandparent of the current working directory; all data paths
# are derived from it.
root_dir = Path(os.path.abspath('')).parents[1]
experiment_dir = os.path.join(root_dir, "axon_geometry")
data_dir = os.path.join(experiment_dir, "data", brain)
segments_swc_dir = os.path.join(data_dir, "segments_swc")
trace_data_dir = os.path.join(data_dir, "trace_data")
print(f"Directory where swcs reside: {segments_swc_dir}")

## Read trace data

In [None]:
# Load per-segment trace data ("<id>.npy" for ids in [0, max_id)) and flatten the
# per-point curvature/torsion samples, together with each point's distance from
# the root as selected by `length_def`.
max_id = 300
seg_lengths = []
mean_torsions = []
mean_curvatures = []
d_from_root = []
torsions = []
curvatures = []
for i in range(max_id):
    trace_data_path = os.path.join(trace_data_dir, "{}.npy".format(i))
    if not os.path.exists(trace_data_path):
        continue
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trace files
    # produced by this project's own pipeline.
    trace_data = np.load(trace_data_path, allow_pickle=True)
    print("Loaded segment {}".format(i))

    for node in trace_data:
        seg_lengths.append(node["seg_length"])
        mean_curvatures.append(node["mean_curvature"])
        mean_torsions.append(node["mean_torsion"])
        _curvatures = node["curvature"]
        _torsions = node["torsion"]
        if len(_curvatures) != len(_torsions):
            raise ValueError(
                f"segment {i}: curvature ({len(_curvatures)}) and torsion "
                f"({len(_torsions)}) sample counts differ"
            )
        m = len(_curvatures)
        # Point k of this node sits k length units after the node's start.
        steps = np.arange(0, m)
        if length_def == 0:  # length from root
            d = node["starting_length"] + steps
        elif length_def == 1:  # relative length from root
            d = (node["starting_length"] + steps) / (
                node["seg_length"] + node["starting_length"]
            )
        elif length_def == 2:  # relative length along segment
            d = steps / node["seg_length"]
        else:
            # Without this, `d` would be undefined or silently reused from the
            # previous node.
            raise ValueError(f"length_def must be 0, 1 or 2, got {length_def!r}")
        d_from_root.append(d)
        torsions.append(_torsions)
        curvatures.append(_curvatures)

if not d_from_root:
    raise FileNotFoundError(
        f"No trace data found in {trace_data_dir} for segment ids < {max_id}"
    )

seg_lengths = np.array(seg_lengths)
mean_curvatures = np.array(mean_curvatures)
mean_torsions = np.array(mean_torsions)
d_from_root = np.concatenate(d_from_root)
curvatures = np.concatenate(curvatures)
torsions = np.concatenate(torsions)

# Remove all root points (distance < 1e-16), where log10 is undefined. A plain
# boolean mask is used on purpose: np.ma shrinks an all-False mask to the scalar
# `nomask`, and indexing with `~nomask` adds an axis instead of selecting
# elements, which would corrupt every later index computed from these arrays.
is_root = d_from_root < 1e-16
keep = ~is_root
d_from_root = d_from_root[keep]
curvatures = curvatures[keep]
torsions = torsions[keep]

log_d_from_root = np.log10(d_from_root)
# Total number of non-root points; used to normalize the histograms.
L = len(log_d_from_root)

## Compute histogram of the curvatures

In [None]:
def _log_distance_histogram(log_d, bins, device="cpu"):
    """Histogram ``log_d`` into ``bins`` equal-width bins spanning its range.

    Parameters
    ----------
    log_d : numpy.ndarray
        1-D array of log10 distances from the root.
    bins : int
        Number of histogram bins.
    device : str or torch.device
        Device on which the returned tensors live.

    Returns
    -------
    values, lo, hi, left_edges, counts : torch.Tensor
        ``values`` is ``log_d`` as a tensor; ``lo``/``hi`` are its min/max
        (NaN when ``log_d`` is empty); ``left_edges`` holds exactly ``bins``
        left bin edges and ``counts`` the per-bin counts (both empty when
        ``log_d`` is empty).
    """
    values = torch.from_numpy(np.ascontiguousarray(log_d)).to(device)
    if values.numel() == 0:
        # torch.min/max raise on empty tensors; return empty histograms instead.
        nan = torch.tensor(float("nan"), dtype=values.dtype, device=values.device)
        empty = values.new_empty(0)
        return values, nan, nan, empty, empty
    lo = torch.min(values)
    hi = torch.max(values)
    # linspace yields exactly `bins` left edges, matching histc's output length;
    # arange with a float step can produce bins + 1 edges through rounding.
    left_edges = torch.linspace(
        lo.item(), hi.item(), bins + 1, dtype=values.dtype, device=values.device
    )[:-1]
    counts = torch.histc(values, bins=bins, min=lo.item(), max=hi.item())
    return values, lo, hi, left_edges, counts


# Split the non-root points by whether their curvature is (numerically) zero.
# The two masks are complementary for finite values, so
# alpha_zero + alpha_nonzero covers every finite sample.
abs_curvatures = np.abs(curvatures)
zero_curvatures_log_d_from_root = log_d_from_root[abs_curvatures < 1e-16]
nonzero_curvatures_log_d_from_root = log_d_from_root[abs_curvatures >= 1e-16]

# Histograms are computed on the CPU.
(
    torch_zero_curvatures_log_d_from_root,
    zero_curvatures_hist_min,
    zero_curvatures_hist_max,
    zero_curvatures_hist_bin_edges,
    zero_curvatures_hist,
) = _log_distance_histogram(zero_curvatures_log_d_from_root, bins)
(
    torch_nonzero_curvatures_log_d_from_root,
    nonzero_curvatures_hist_min,
    nonzero_curvatures_hist_max,
    nonzero_curvatures_hist_bin_edges,
    nonzero_curvatures_hist,
) = _log_distance_histogram(nonzero_curvatures_log_d_from_root, bins)

# Fraction of all non-root points with zero / non-zero curvature.
alpha_zero_curvatures = len(zero_curvatures_log_d_from_root) / L
alpha_nonzero_curvatures = len(nonzero_curvatures_log_d_from_root) / L
print(alpha_zero_curvatures, alpha_nonzero_curvatures)

## Compute histogram of the torsions (uses GPU)

In [None]:
def _log_distance_histogram(log_d, bins, device="cpu"):
    """Histogram ``log_d`` into ``bins`` equal-width bins spanning its range.

    Parameters
    ----------
    log_d : numpy.ndarray
        1-D array of log10 distances from the root.
    bins : int
        Number of histogram bins.
    device : str or torch.device
        Device on which the returned tensors live.

    Returns
    -------
    values, lo, hi, left_edges, counts : torch.Tensor
        ``values`` is ``log_d`` as a tensor; ``lo``/``hi`` are its min/max
        (NaN when ``log_d`` is empty); ``left_edges`` holds exactly ``bins``
        left bin edges and ``counts`` the per-bin counts (both empty when
        ``log_d`` is empty).
    """
    values = torch.from_numpy(np.ascontiguousarray(log_d)).to(device)
    if values.numel() == 0:
        # torch.min/max raise on empty tensors; return empty histograms instead.
        nan = torch.tensor(float("nan"), dtype=values.dtype, device=values.device)
        empty = values.new_empty(0)
        return values, nan, nan, empty, empty
    lo = torch.min(values)
    hi = torch.max(values)
    # linspace yields exactly `bins` left edges, matching histc's output length;
    # arange with a float step can produce bins + 1 edges through rounding.
    left_edges = torch.linspace(
        lo.item(), hi.item(), bins + 1, dtype=values.dtype, device=values.device
    )[:-1]
    counts = torch.histc(values, bins=bins, min=lo.item(), max=hi.item())
    return values, lo, hi, left_edges, counts


# Run the torsion histograms on the GPU when one is available, else on the CPU.
# (`device` was previously referenced here without ever being defined.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split the non-root points by whether their torsion is (numerically) zero.
# Torsion of a space curve may be signed (TODO confirm how the trace data stores
# it), so test its magnitude: a negative torsion must not count as "zero". This
# is a no-op if the stored torsions are already non-negative.
abs_torsions = np.abs(torsions)
zero_torsions_log_d_from_root = log_d_from_root[abs_torsions < 1e-16]
nonzero_torsions_log_d_from_root = log_d_from_root[abs_torsions >= 1e-16]

(
    torch_zero_torsions_log_d_from_root,
    zero_torsions_hist_min,
    zero_torsions_hist_max,
    zero_torsions_hist_bin_edges,
    zero_torsions_hist,
) = _log_distance_histogram(zero_torsions_log_d_from_root, bins, device)
(
    torch_nonzero_torsions_log_d_from_root,
    nonzero_torsions_hist_min,
    nonzero_torsions_hist_max,
    nonzero_torsions_hist_bin_edges,
    nonzero_torsions_hist,
) = _log_distance_histogram(nonzero_torsions_log_d_from_root, bins, device)

# Fraction of all non-root points with zero / non-zero torsion.
alpha_zero_torsions = len(zero_torsions_log_d_from_root) / L
alpha_nonzero_torsions = len(nonzero_torsions_log_d_from_root) / L
print(alpha_zero_torsions, alpha_nonzero_torsions)

## Plot figure

In [None]:
# Plot, side by side, the distribution of log distance from the root for points
# with zero vs. non-zero curvature (left) and torsion (right), then save the
# figure as JPG and EPS under <experiment_dir>/figures.
fig = plt.figure(figsize=(21, 7))
axes = fig.subplots(1, 2)
GRAY = "#999999"
TITLE_TYPE_SETTINGS = {"fontname": "Arial", "size": 20}
SUP_TITLE_TYPE_SETTINGS = {"fontname": "Arial", "size": 24}
plt.rc("font", family="Arial", size=20)

XLABELS = {
    0: r"$\log$ distance from cell body ($\mu m$)",  # length from root
    1: r"$\log$ relative distance from cell body",  # relative length from root
    2: r"$\log$ relative distance along segment",  # relative length along segment
}
if length_def not in XLABELS:
    # Without this check `xlabel` would be undefined below.
    raise ValueError(f"length_def must be one of {sorted(XLABELS)}, got {length_def!r}")
xlabel = XLABELS[length_def]


def _bar_width(left_edges):
    """Return the spacing of equally spaced left bin edges.

    Falls back to 1.0 when fewer than two edges exist (e.g. an empty
    histogram), where the width does not matter.
    """
    edges = left_edges.cpu().numpy()
    return float(edges[1] - edges[0]) if len(edges) > 1 else 1.0


# (axis, title, [(left bin edges, counts, legend label), ...]) per panel.
panels = [
    (
        axes[0],
        r"Curvature ($\alpha = %.4f$)" % alpha_zero_curvatures,
        [
            (zero_curvatures_hist_bin_edges, zero_curvatures_hist, r"$\mathcal{k} = 0$"),
            (nonzero_curvatures_hist_bin_edges, nonzero_curvatures_hist, r"$\mathcal{k} > 0$"),
        ],
    ),
    (
        axes[1],
        r"Torsion ($\alpha = %.4f$)" % alpha_zero_torsions,
        [
            (zero_torsions_hist_bin_edges, zero_torsions_hist, r"$\tau = 0$"),
            (nonzero_torsions_hist_bin_edges, nonzero_torsions_hist, r"$\tau > 0$"),
        ],
    ),
]

for ax, panel_title, series in panels:
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color(GRAY)
    ax.tick_params(axis="both", colors=GRAY, labelsize="large")

    for left_edges, counts, label in series:
        # Each bar starts at its bin's left edge and is exactly one bin wide.
        # Matplotlib's default (centered, width 0.8 data units) is unrelated to
        # the bin width. .cpu() handles tensors computed on either device.
        ax.bar(
            left_edges.cpu().numpy(),
            counts.cpu().numpy() / L,
            width=_bar_width(left_edges),
            align="edge",
            label=label,
            alpha=0.5,
        )

    ax.set_title(panel_title)
    ax.set_xlabel(xlabel, fontsize=24)
    # NOTE(review): bar heights are per-bin probability mass (counts / L), not a
    # density; the "pdf" label is kept as in the original figure.
    ax.set_ylabel(r"pdf", fontsize=24)
    leg = ax.legend(loc=2)
    leg.get_frame().set_edgecolor(GRAY)

title = "Brain 1" if brain == "brain1" else "Brain 2"
fig.suptitle(title, **SUP_TITLE_TYPE_SETTINGS)

# savefig fails if the output directory does not exist yet.
figures_dir = os.path.join(experiment_dir, "figures")
os.makedirs(figures_dir, exist_ok=True)
fig.savefig(os.path.join(figures_dir, "{}_histograms_from_root_{}.jpg".format(brain, length_def)))
fig.savefig(os.path.join(figures_dir, "{}_histograms_from_root_{}.eps".format(brain, length_def)))