In [None]:
from tqdm import tqdm
import pickle
from pathlib import Path
import os

from skimage import io, measure
import numpy as np
import h5py
from brainlit.preprocessing import image_process
from brainlit.algorithms.connect_fragments import dynamic_programming_viterbi2
import scipy.ndimage as ndi
from sklearn.metrics import pairwise_distances_argmin_min
import napari
import networkx as nx

# Read Data

In [None]:
root_dir = Path(os.path.abspath("")).parents[0]
data_dir = os.path.join(root_dir, "data", "example")
im_path = os.path.join(data_dir, "image.tif")
probs_path = os.path.join(data_dir, "probabilities.h5")
coords_path = os.path.join(data_dir, "manual_coords.csv")

# Voxel spacing of the image, one value per image axis. Later cells reuse this
# variable wherever a resolution is needed.
res = [0.3, 0.3, 1]

im_og = io.imread(im_path, plugin="tifffile")
print(f"Image shape: {im_og.shape}")

# Manually traced coordinates, one point per row (cast to integer voxel indices).
# The first row is used as the soma point and the last row as the axon point;
# `coords_list` holds the same rows in reverse order (axon -> soma) and is drawn
# later as the manual trace.
coords = np.genfromtxt(coords_path, delimiter=",").astype(int)
coords_list = [list(c) for c in coords[::-1]]
print(f"coords shape: {coords.shape}")
soma_coords = [list(coords[0, :])]
axon_coords = [list(coords[-1, :])]

# Take index 1 along the last axis of the exported probability volume
# (presumably the foreground-class channel -- TODO confirm against the export).
# Slicing inside the `with` block reads the data into memory, so the HDF5 file
# is closed immediately instead of staying open for the rest of the session.
# Indexing with `f[...]` raises a clear KeyError if the dataset is missing,
# whereas `f.get(...)` would return None and fail later with a TypeError.
with h5py.File(probs_path, "r") as f:
    pred = f["exported_data"][:, :, :, 1]
im_processed = pred

# Voxels whose probability exceeds this value are treated as foreground.
threshold = 0.9
labels = measure.label(im_processed > threshold)

# Process Components into Fragments

In [None]:
new_labels = image_process.split_frags(
    soma_coords, labels, im_processed, threshold, res
)


def _component_of_points(label_img, points, resolution):
    """Look up the labels that ``image_process.label_points`` assigns to ``points``.

    Returns a tuple ``(point_lbls, first_lbl, first_mask)``:
    - ``point_lbls``: the label sequence returned by ``label_points``,
    - ``first_lbl``: its first element,
    - ``first_mask``: boolean array, True where ``label_img == first_lbl``.
    """
    _, point_lbls = image_process.label_points(label_img, points, resolution)
    first_lbl = point_lbls[0]
    return point_lbls, first_lbl, label_img == first_lbl


# Label (and voxel mask) assigned to the axon point in the split fragments.
axon_lbls, axon_lbl, axon_mask = _component_of_points(new_labels, axon_coords, res)
# Label (and voxel mask) assigned to the soma point in the split fragments.
soma_lbls, soma_lbl, soma_mask = _component_of_points(new_labels, soma_coords, res)

## View Data and Processed Labels

In [None]:
# Open a 3D napari viewer to inspect, side by side:
# - the raw image (`im_og`),
# - the thresholded connected components (`labels`),
# - the labels after `split_frags` (`new_labels`),
# - boolean masks of the labels assigned to the axon and soma points.
# Layer names are taken from the argument variable names by napari, so these
# calls are intentionally left passing plain variables.
viewer = napari.Viewer(ndisplay=3)

viewer.add_image(im_og)
viewer.add_labels(labels)
viewer.add_labels(new_labels)
viewer.add_labels(axon_mask)
viewer.add_labels(soma_mask)

# Reconstruct Axon

In [None]:
# Coefficients handed to the path model. Presumably they weight the distance and
# curvature terms of the transition cost -- TODO confirm in
# dynamic_programming_viterbi2.most_probable_neuron_path.
COEF_DIST = 10
COEF_CURV = 1000

# The voxel spacing comes from `res` (data-loading cell) rather than a repeated
# literal, so the two cannot drift apart. It is converted to a tuple to keep
# the argument type this call has always received.
mpnp = dynamic_programming_viterbi2.most_probable_neuron_path(
    image=im_og.astype(float),
    labels=new_labels,
    soma_lbls=soma_lbls,
    resolution=tuple(res),
    coef_dist=COEF_DIST,
    coef_curv=COEF_CURV,
)
mpnp.frags_to_lines()
mpnp.reset_dists(type="all")
# Distance-based transition costs (stored in `mpnp.cost_mat_dist`), computed with
# the model's own point-to-point and point-to-blob distance functions.
mpnp.compute_all_costs_dist(
    point_point_func=mpnp.point_point_dist, point_blob_func=mpnp.point_blob_dist
)
# Intensity-based transition costs (stored in `mpnp.cost_mat_int`).
mpnp.compute_all_costs_int()
# Builds `mpnp.nxGraph`, the state graph the shortest-path search runs on.
mpnp.create_nx_graph()

## Choose which state to start from

In [None]:
# Re-derive both labels from the lists returned by `label_points`, so this cell
# gives the same result regardless of what values `axon_lbl` / `soma_lbl` hold
# from other cells (re-running cells out of order must not change the result).
axon_lbl = axon_lbls[0]
soma_lbl = soma_lbls[0]

# The axon fragment maps to (at least) two states; presumably one per traversal
# direction of the fragment -- TODO confirm in dynamic_programming_viterbi2.
start1 = mpnp.comp_to_states[axon_lbl][0]
start2 = mpnp.comp_to_states[axon_lbl][1]
end_state = mpnp.comp_to_states[soma_lbl][0]

# In this example the correct starting state is known to be `start2`. Otherwise,
# inspect `start1` and `start2` and choose the appropriate one.
start = start2

In [None]:
# Minimum-cost sequence of states from the chosen axon state to the soma state,
# using each edge's "weight" attribute as its cost.
path_states = nx.shortest_path(
    mpnp.nxGraph,
    source=start,
    target=end_state,
    weight="weight",
)

# Plot Result

In [None]:
# Component label visited by each state on the path.
path_comps = [mpnp.state_to_comp[state][1] for state in path_states]
print(f"path sequence: {path_states}")
print(f"component sequence: {path_comps}")

# Label each component on the path with its 1-based position along the path.
path_mask = np.zeros_like(new_labels)
for order, label in enumerate(path_comps, start=1):
    path_mask[new_labels == label] = order

# Integer mask of all soma components, each keeping its own label. A distinct
# loop variable keeps the notebook-level `soma_lbl` from being rebound. A
# distinct mask name keeps the boolean `soma_mask` from the processing cell
# intact.
soma_lbl_mask = np.zeros_like(new_labels)
for lbl in mpnp.soma_lbls:
    soma_lbl_mask[new_labels == lbl] = lbl

viewer = napari.Viewer(ndisplay=3)
viewer.add_image(mpnp.image)
viewer.add_labels(new_labels)
viewer.add_labels(path_mask)
viewer.add_labels(soma_lbl_mask)
viewer.add_labels(new_labels == axon_lbl)

# Mark the axon point in red.
viewer.add_points([axon_coords[0]], face_color="red", size=10)

# Walk the path, reporting each transition's costs and collecting polyline
# vertices: both endpoints of every fragment state, and the soma connection
# point of the fragment preceding a non-fragment state. `prev_state` is None
# for the first state, so the previous-state lookups never wrap around to the
# last element of `path_states`.
lines = []
cumul_cost = 0
prev_state = None
for s, state in enumerate(path_states):
    comp = mpnp.state_to_comp[state]
    if prev_state is not None:
        prev_comp = mpnp.state_to_comp[prev_state]
        dist_cost = mpnp.cost_mat_dist[prev_state, state]
        int_cost = mpnp.cost_mat_int[prev_state, state]
        cumul_cost += dist_cost + int_cost
        print(
            f"Trans. #{s}: dist cost state {prev_state}->state {state}, comp {prev_comp[1]}->comp {comp[1]}: {dist_cost:.2f}, int cost: {int_cost:.2f}, cum. cost: {cumul_cost:.2f}"
        )
    if comp[0] == "fragment":
        lines.append(list(comp[2]["coord1"]))
        lines.append(list(comp[2]["coord2"]))
    elif prev_state is not None and mpnp.state_to_comp[prev_state][0] == "fragment":
        lines.append(
            list(mpnp.state_to_comp[prev_state][2]["soma connection point"])
        )
    prev_state = state

# `coords_list` runs axon -> soma: anchor the reconstructed path at both ends.
lines.insert(0, coords_list[0])
lines.append(coords_list[-1])
# Blue: reconstructed path. Green: manual trace from manual_coords.csv.
viewer.add_shapes(lines, shape_type="path", edge_color="blue", edge_width=2)
viewer.add_shapes(coords_list, shape_type="path", edge_color="green", edge_width=2)

viewer.camera.angles = [0, -90, 180]
napari.run()