In [65]:
from pathlib import Path
import matplotlib.pyplot as plt
from cloudvolume import CloudVolume
from mpl_toolkits.mplot3d import Axes3D
from brainlit.algorithms.trace_analysis.fit_spline import GeometricGraph
from brainlit.utils.Neuron_trace import NeuronTrace
from scipy.interpolate import splev
import numpy as np
import networkx as nx
from scipy.spatial import cKDTree

%matplotlib qt

## Smooth skeletons and convert to SWC

In [104]:
def skel_to_graph(skel):
    """Build an undirected networkx graph from a skeleton.

    Node ``i`` corresponds to ``skel.vertices[i]`` and stores that vertex under
    the ``"loc"`` attribute. Each row of ``skel.edges`` becomes an edge between
    its first two entries.

    Args:
        skel: skeleton-like object exposing ``vertices`` and ``edges``.

    Returns:
        nx.Graph with one node per vertex and one edge per skeleton edge.
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (index, {"loc": coords}) for index, coords in enumerate(skel.vertices)
    )
    graph.add_edges_from((edge[0], edge[1]) for edge in skel.edges)
    return graph


def smooth_graph(G, depth_limit=20, radius=25000):
    """Smooth the ``"loc"`` attribute of degree-2 nodes by neighbourhood averaging.

    For each node with exactly two neighbours, collect every node within
    ``depth_limit`` hops of it, including the node itself. The node's location
    is replaced with the unweighted mean of those neighbours that lie strictly
    closer than ``radius`` (Euclidean) to it. All other nodes keep their
    location.

    All new locations are computed from the original ones and written back
    together afterwards. The result therefore does not depend on the order in
    which nodes are visited.

    Args:
        G: networkx graph whose nodes carry a ``"loc"`` coordinate vector.
            Modified in place.
        depth_limit: maximum hop distance of the neighbourhood (default 20).
        radius: Euclidean distance cutoff, in the same units as ``"loc"``
            (default 25000).

    Returns:
        The same graph ``G`` with updated ``"loc"`` values. Also prints how many
        nodes had at least one other node contribute to their average.
    """
    new_locs = {}
    av_count = 0

    for node in G.nodes:
        if G.degree(node) != 2:
            continue

        center = np.asarray(G.nodes[node]["loc"], dtype=float)
        # Hop-limited neighbourhood, including `node` itself at distance 0.
        # On a tree this is the same set a depth-limited DFS would reach.
        # Unlike DFS, it does not depend on traversal order when the skeleton
        # has cycles, and it avoids building a DFS tree per node.
        nbrs = nx.single_source_shortest_path_length(G, node, cutoff=depth_limit)
        locs = np.array([G.nodes[n]["loc"] for n in nbrs], dtype=float)
        dists = np.linalg.norm(locs - center, axis=1)
        weights = (dists < radius).astype(float)

        # The node itself always has weight 1, so a sum > 1 means at least
        # one other node contributed to the average.
        if weights.sum() > 1:
            av_count += 1

        new_locs[node] = np.average(locs, axis=0, weights=weights)

    for node, loc in new_locs.items():
        G.nodes[node]["loc"] = loc

    print(f"{av_count} averaged nodes")
    return G


def graph_to_vertices(G):
    """Stack the ``"loc"`` attribute of every node of ``G`` into a NumPy array.

    Rows follow ``G.nodes`` iteration order. For a graph produced by
    ``skel_to_graph`` this presumably matches the original vertex indexing.
    Verify this if the graph was edited in between.

    Args:
        G: graph whose ``G.nodes[node]`` mapping has a ``"loc"`` entry.

    Returns:
        np.ndarray with one row per node.
    """
    return np.array([G.nodes[node]["loc"] for node in G.nodes])

In [105]:
# Toy sanity check: a 5-node path 0-1-2-3-4. Node 1 sits next to node 0, and
# nodes 2-4 are stacked at (10000, 10000, 10000).
G = nx.path_graph(5)
nx.set_node_attributes(
    G,
    {
        0: [0, 0, 0],
        1: [1, 2, 1],
        2: [10000, 10000, 10000],
        3: [10000, 10000, 10000],
        4: [10000, 10000, 10000],
    },
    name="loc",
)

G = smooth_graph(G)
G.nodes[1]

3 averaged nodes


In [106]:
# Smooth every traced skeleton in each volume and write it out as
# "<skel_id>_smoothed.swc" next to the volume's "traces" directory.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
data_dir = Path("/Users/thomasathey/Documents/mimlab/mouselight/kolodkin/sriram/misc")
subdirs = ["220-p29-brain1", "220-p29-brain2", "adipo-brain1-im3"]
max_skeletons = 10  # skeleton ids 0..max_skeletons-1 are tried per volume

for subdir in subdirs:
    trace_dir = data_dir / subdir / "traces"
    vol = CloudVolume("precomputed://file://" + str(trace_dir))
    for skel_id in range(max_skeletons):
        print(skel_id)
        # Only the lookup is guarded. A failure here is taken to mean "no more
        # skeletons in this volume". Errors in smoothing or writing now
        # propagate instead of being misreported as an invalid id.
        try:
            skel = vol.skeleton.get(skel_id)
        except Exception:
            print(f"{skel_id} invalid for {subdir}")
            break

        G = smooth_graph(skel_to_graph(skel))
        skel.vertices = graph_to_vertices(G)
        # NOTE(review): copies radii into vertex_types (written to the SWC type
        # column). Presumably a workaround for missing vertex types; confirm intent.
        skel.vertex_types = skel.radii
        txt = skel.to_swc()
        with open(data_dir / subdir / f"{skel_id}_smoothed.swc", "w") as f:
            f.write(txt)

0
4216 averaged nodes
1
9498 averaged nodes
2
9705 averaged nodes
3
10974 averaged nodes
4
14124 averaged nodes
5
6094 averaged nodes
6
9523 averaged nodes
7
7 invalid for 220-p29-brain1
0
5411 averaged nodes
1
6990 averaged nodes
2
7081 averaged nodes
3
6953 averaged nodes
4
6568 averaged nodes
5
7419 averaged nodes
6
6685 averaged nodes
7
6720 averaged nodes
8
5809 averaged nodes
9
9 invalid for 220-p29-brain2
0
846 averaged nodes
1
3469 averaged nodes
2
2219 averaged nodes
3
2578 averaged nodes
4
1875 averaged nodes
5
5 invalid for adipo-brain1-im3


## Plot the fitted spline tree of one smoothed trace

In [108]:
# Fit a spline tree to one smoothed SWC trace and plot it in 3D.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
data_dir = Path("/Users/thomasathey/Documents/mimlab/mouselight/kolodkin/sriram/misc")
subdirs = ["220-p29-brain1", "220-p29-brain2", "adipo-brain1-im3"]

subdir_choice = subdirs[0]
skel_id_choice = 1

# Build the path directly rather than looping over every (subdir, skel_id)
# pair and skipping all but the chosen one.
swc_path = data_dir / subdir_choice / f"{skel_id_choice}_smoothed.swc"
swc_trace = NeuronTrace(path=str(swc_path))
df_swc_offset_neuron = swc_trace.get_df()
print("Loaded segment {}".format(skel_id_choice))

G = GeometricGraph(df=df_swc_offset_neuron, remove_duplicates=True)
print("Fitting spline tree")
spline_tree = G.fit_spline_tree_invariant()

print("plotting")
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
for node in spline_tree.nodes:
    tck, u_um = spline_tree.nodes[node]["spline"]
    # Presumably a parametric spline, so splev returns one coordinate array
    # per dimension.
    coords = splev(u_um, tck)
    ax.plot(coords[0], coords[1], coords[2], c="blue", linewidth=0.5)

ax.set(
    title=f"{subdir_choice}, segment {skel_id_choice}",
    xlabel="x",
    ylabel="y",
    zlabel="z",
)
plt.show()



Loaded segment 1
Fitting spline tree




plotting
