# Plot segment-length distributions (zero vs. nonzero curvature and torsion)

In [None]:
import numpy as np
import brainlit
import scipy
from brainlit.utils import swc
from cloudvolume.exceptions import SkeletonDecodeError
from brainlit.algorithms.trace_analysis.fit_spline import GeometricGraph
from brainlit.algorithms.trace_analysis.spline_fxns import curvature, torsion
import os
from pathlib import Path
import pandas as pd
from networkx.readwrite import json_graph
import json
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity
import torch

## Define brain, find trace data folder

In [None]:
# Which brain to analyse: set to 1 or 2.
brain_id = 1
brain = f"brain{brain_id}"

# root_dir is the grandparent of the current working directory; all experiment
# data for the chosen brain lives under <root_dir>/axon_geometry/data/<brain>.
root_dir = Path(os.path.abspath("")).parents[1]
experiment_dir = os.path.join(root_dir, "axon_geometry")
data_dir = os.path.join(experiment_dir, "data", brain)
segments_swc_dir, trace_data_dir = (
    os.path.join(data_dir, subdir) for subdir in ("segments_swc", "trace_data")
)
print(f"Directory where swcs reside: {segments_swc_dir}")

## Read trace data

In [None]:
# Collect per-node statistics from every trace file <trace_data_dir>/<id>.npy
# for ids 0 .. max_id - 1.
max_id = 300
seg_lengths = []
mean_torsions = []
mean_curvatures = []
# NOTE(review): the next three lists are not filled in this cell; kept in case
# other cells of the notebook refer to them.
d_from_root = []
torsions = []
curvatures = []
for i in range(max_id):
    trace_data_path = os.path.join(trace_data_dir, "{}.npy".format(i))
    if not os.path.exists(trace_data_path):
        # Not every id in [0, max_id) necessarily has a trace file; skip it.
        continue
    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    trace_data = np.load(trace_data_path, allow_pickle=True)
    print("Loaded segment {}".format(i))

    for node in trace_data:
        seg_lengths.append(node["seg_length"])
        mean_curvatures.append(node["mean_curvature"])
        mean_torsions.append(node["mean_torsion"])

# Fail here with a clear message instead of an opaque error in later cells
# (e.g. min() of an empty sequence) when nothing was loaded.
if not seg_lengths:
    raise FileNotFoundError(
        f"No trace data found in {trace_data_dir} for segment ids 0..{max_id - 1}"
    )

seg_lengths = np.array(seg_lengths)
mean_curvatures = np.array(mean_curvatures)
# Keep only the magnitude of the torsion, so negative values are treated the
# same as positive ones by any later thresholding.
mean_torsions = np.abs(np.array(mean_torsions))

## Define helper variables

In [None]:
# Work with segment lengths on a log10 scale.
log_seg_lengths = np.log10(seg_lengths)
min_log_seg_length, max_log_seg_length = min(log_seg_lengths), max(log_seg_lengths)
# Evaluation grid for the KDEs below: 1000 evenly spaced points spanning the
# observed range, shaped as a single column (n_points, 1).
xx = np.linspace(min_log_seg_length, max_log_seg_length, num=1000).reshape(-1, 1)

## Compute KDEs of log segment length for zero- and nonzero-curvature segments

In [None]:
# Split segments into "zero" and "nonzero" mean curvature groups. The two
# comparisons are complementary (< and >=) so every non-NaN segment lands in
# exactly one group; previously a value of exactly 1e-16 matched neither.
zero_tol = 1e-16
zero_curvature_mask = mean_curvatures < zero_tol
nonzero_curvature_mask = mean_curvatures >= zero_tol
zero_curvatures_log_seg_lengths = log_seg_lengths[zero_curvature_mask]
nonzero_curvatures_log_seg_lengths = log_seg_lengths[nonzero_curvature_mask]

# Gaussian KDE of log segment length for each group.
# NOTE(review): the groups use different bandwidths (0.1 vs 0.25); the reason
# is not documented here.
zero_kde = KernelDensity(kernel="gaussian", bandwidth=0.1).fit(
    zero_curvatures_log_seg_lengths[:, np.newaxis]
)
nonzero_kde = KernelDensity(kernel="gaussian", bandwidth=0.25).fit(
    nonzero_curvatures_log_seg_lengths[:, np.newaxis]
)
# score_samples returns log-densities evaluated on the grid xx.
zero_log_dens = zero_kde.score_samples(xx)
nonzero_log_dens = nonzero_kde.score_samples(xx)

# Weight each group's density by the fraction of segments it contains, so the
# two curves together describe the full population.
alpha_zero_curvatures = len(zero_curvatures_log_seg_lengths) / len(seg_lengths)
alpha_nonzero_curvatures = len(nonzero_curvatures_log_seg_lengths) / len(seg_lengths)
print(alpha_zero_curvatures, alpha_nonzero_curvatures)
zero_curvatures_norm_pdf = alpha_zero_curvatures * np.exp(zero_log_dens)
nonzero_curvatures_norm_pdf = alpha_nonzero_curvatures * np.exp(nonzero_log_dens)

## Compute KDEs of log segment length for zero- and nonzero-torsion segments

In [None]:
# Split segments into "zero" and "nonzero" mean |torsion| groups. The two
# comparisons are complementary (< and >=) so every non-NaN segment lands in
# exactly one group; previously a value of exactly 1e-16 matched neither.
zero_tol = 1e-16
zero_torsion_mask = mean_torsions < zero_tol
nonzero_torsion_mask = mean_torsions >= zero_tol
zero_torsions_log_seg_lengths = log_seg_lengths[zero_torsion_mask]
nonzero_torsions_log_seg_lengths = log_seg_lengths[nonzero_torsion_mask]

# Gaussian KDE of log segment length for each group.
# NOTE(review): the groups use different bandwidths (0.1 vs 0.25); the reason
# is not documented here.
zero_kde = KernelDensity(kernel="gaussian", bandwidth=0.1).fit(
    zero_torsions_log_seg_lengths[:, np.newaxis]
)
nonzero_kde = KernelDensity(kernel="gaussian", bandwidth=0.25).fit(
    nonzero_torsions_log_seg_lengths[:, np.newaxis]
)
# score_samples returns log-densities evaluated on the grid xx.
zero_log_dens = zero_kde.score_samples(xx)
nonzero_log_dens = nonzero_kde.score_samples(xx)

# Weight each group's density by the fraction of segments it contains, so the
# two curves together describe the full population.
alpha_zero_torsions = len(zero_torsions_log_seg_lengths) / len(seg_lengths)
alpha_nonzero_torsions = len(nonzero_torsions_log_seg_lengths) / len(seg_lengths)
print(alpha_zero_torsions, alpha_nonzero_torsions)
zero_torsions_norm_pdf = alpha_zero_torsions * np.exp(zero_log_dens)
nonzero_torsions_norm_pdf = alpha_nonzero_torsions * np.exp(nonzero_log_dens)

## Plot figure

In [None]:
GRAY = "#999999"
TITLE_TYPE_SETTINGS = {"fontname": "Arial", "size": 20}
SUP_TITLE_TYPE_SETTINGS = {"fontname": "Arial", "size": 24}
# rcParams are read when artists are created (each axes builds its title Text
# objects when it is created), so set the font before creating the figure.
plt.rc("font", family="Arial", size=20)

fig = plt.figure(figsize=(21, 7))
axes = fig.subplots(1, 2)

# The *_norm_pdf curves were evaluated on the grid ``xx`` from the
# helper-variables cell; that same grid is reused here.


def _style_axis(ax):
    """Color all four spines and the tick marks/labels of ``ax`` gray."""
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color(GRAY)
    ax.tick_params(axis="both", colors=GRAY, labelsize="large")


def _plot_density_pair(ax, zero_pdf, nonzero_pdf, zero_label, nonzero_label, title):
    """Fill the zero/nonzero weighted densities on ``ax`` and label the panel.

    The zero-group density is additionally drawn as a dashed line wherever it
    is at or below the nonzero density, i.e. where the nonzero fill (drawn
    second) lies over it.
    """
    grid = xx.squeeze()
    ax.fill_between(grid, 0, zero_pdf, alpha=0.7, label=zero_label)
    ax.fill_between(grid, 0, nonzero_pdf, alpha=0.7, label=nonzero_label)
    # NaN out the points where the zero density is on top: matplotlib breaks
    # the line at NaNs instead of joining separate runs with a straight chord.
    dashed = np.where(zero_pdf <= nonzero_pdf, zero_pdf, np.nan)
    ax.plot(grid, dashed, "--")
    ax.set_title(title)
    ax.set_xlabel(r"$\log$ segment length ($\mu m$)", fontsize=24)
    ax.set_ylabel(r"pdf", fontsize=24)
    leg = ax.legend(loc=1)
    leg.get_frame().set_edgecolor(GRAY)
    ax.set_xticks([1, 2, 3, 4])


_style_axis(axes[0])
_plot_density_pair(
    axes[0],
    zero_curvatures_norm_pdf,
    nonzero_curvatures_norm_pdf,
    r"$\mathcal{k} = 0$",
    r"$\mathcal{k} > 0$",
    r"Curvature ($\alpha = %.2f$)" % alpha_zero_curvatures,
)

_style_axis(axes[1])
_plot_density_pair(
    axes[1],
    zero_torsions_norm_pdf,
    nonzero_torsions_norm_pdf,
    r"$\tau = 0$",
    r"$\tau > 0$",
    r"Torsion ($\alpha = %.2f$)" % alpha_zero_torsions,
)

fig.suptitle("Brain {}".format(brain_id))

# savefig does not create missing directories, so make sure it exists.
figures_dir = os.path.join(experiment_dir, "figures")
os.makedirs(figures_dir, exist_ok=True)
plt.savefig(os.path.join(figures_dir, "{}_histograms.jpg".format(brain)))
plt.savefig(os.path.join(figures_dir, "{}_histograms.eps".format(brain)))