In this notebook, we start from one annotated lineage, collect the lineages whose entry points lie nearby, and check whether each neuron's lineage annotation is correct. The check uses a naive Bayes classifier trained on the local environment around that lineage's entry point.

In [None]:
import os

import catalysis as cat
import catalysis.pynblast as pynblast
import catalysis.plt as catplt
import catalysis.transform as transform
import catalysis.completeness as completeness
import catalysis.lineage_classifier as lineage_classifier

import plotly.offline as py
import plotly.graph_objs as go

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as cl

import numpy as np
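import scipy as sp
import scipy.spatial.distance  # `sp.spatial` is not guaranteed to be loaded without an explicit import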
import pandas as pd
import networkx as nx
import re
import dill as pickle

from IPython.display import display, HTML

from itertools import chain, cycle
from sklearn import cluster
from sklearn.neighbors import KernelDensity

from importlib import reload

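# `__file__` is not defined inside a Jupyter notebook, so use the working directory instead.
# Like the original, this assumes the notebook sits one directory below the folder containing `data/`.
HERE = os.path.dirname(os.path.abspath(os.getcwd()))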

In [None]:
l1data = cat.CatmaidDataInterface.from_json(os.environ["CATMAID_CREDENTIALS"])

In [None]:
adult_fly_f = pd.read_csv(os.path.join(HERE, "data/smat_jefferis.csv"), delimiter=" ")
smat = pynblast.ScoreMatrixLookup.from_dataframe(adult_fly_f)
# Rescale the distance bins of the adult-fly score matrix to L1 size, converting µm to nm (x1000).
# The factor of 1/4 is based on the scale-up observed in the L1/L3 data papers and on the desired scoring properties.
smat.d_range = smat.d_range * 1000 / 4

with open(os.path.join(HERE, "data/Brain_Lineage_Landmarks_EMtoEM_ProjectSpace.csv")) as fid:
    lin_landmarks = pd.read_csv(fid)

side_name = ["l_loc", "r_loc"]

wb_match = re.compile(r"\*.*akira")
lineage_parser = re.compile(r"\*(?P<group>.*?)_(?P<instance>[rl]) akira")
hemilateral_groups = l1data.match_groups_from_select_annotations(
    wb_match, lineage_parser
)

all_lins = []
for lin in sorted(list(hemilateral_groups.keys())):
    for side in hemilateral_groups[lin]:
        all_lins.append(hemilateral_groups[lin][side])

lin_df = lineage_classifier.lineage_table(hemilateral_groups, lin_landmarks, side_name)

---
## Continued thoughts on bundling approach.

We want to use NBLAST to identify which lineage each neuron belongs to. Although lineage bundles look clear in principle, in practice they often are not: in some cases, only the EM data reveals which neurites genuinely bundle together. Part of the problem may be that the NBLAST scoring function needs to be retuned for this purpose, since bundled neurites are far more similar to one another than whole arbors in a transformed space are. The first step should therefore be to recompute the NBLAST scores for this particular domain. Even then there will be outliers, so the classification should be softer than a simple k-means clustering and should produce a ranked list of candidate lineages.

For a given neuron $\mathbf{x}$ and set of lineages $\{L_i\}$, we want to find the lineage $L_i$ that maximizes $P( \mathbf{x} \in L_i \mid \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i})$, where $\langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i}$ denotes the mean NBLAST similarity score between $\mathbf{x}$ and all neurons $\mathbf{y}$ in $L_i$.
By Bayes's rule, the probability that $\mathbf{x} \in L_i$ given this evidence is:
$$
P\left( \mathbf{x} \in L_i \mid \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \right) =
    P\left( \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \mid \mathbf{x} \in L_i \right)
    \frac{P\left( \mathbf{x} \in L_i \right)}{P\left( \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \right)}
$$
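Expanding the evidence term by the law of total probability, $P(d_i) = P(d_i \mid \mathbf{x} \in L_i)P(\mathbf{x} \in L_i) + P(d_i \mid \mathbf{x} \notin L_i)P(\mathbf{x} \notin L_i)$ with $d_i = \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i}$, shows why both the member and non-member distributions are needed, and that for a fixed prior the posterior depends on the data only through their ratio. A minimal numeric sketch; the function names are hypothetical and not part of `catalysis`:

```python
def posterior_in_lineage(p_match, p_unmatch, prior):
    """Bayes's rule for one lineage L_i (hypothetical helper, not catalysis API).

    p_match   -- P(d_i | x in L_i), density of the mean score under "member"
    p_unmatch -- P(d_i | x not in L_i), density under "non-member"
    prior     -- P(x in L_i)
    """
    # Law of total probability over the two hypotheses gives the evidence P(d_i).
    evidence = p_match * prior + p_unmatch * (1.0 - prior)
    return p_match * prior / evidence


def posterior_from_ratio(likelihood_ratio, prior):
    """The same posterior written in terms of the likelihood ratio alone."""
    return likelihood_ratio * prior / (likelihood_ratio * prior + (1.0 - prior))


print(posterior_in_lineage(0.8, 0.2, 0.5))   # 0.8
print(posterior_from_ratio(0.8 / 0.2, 0.5))  # 0.8
# A lineage given prior 0 (e.g. one far from the neuron) gets posterior 0.
print(posterior_in_lineage(0.8, 0.2, 0.0))   # 0.0
```

For any prior strictly between 0 and 1 the posterior increases with the likelihood ratio, so under a flat prior, ranking lineages by likelihood ratio gives the same order as ranking by posterior.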

We restrict the comparison to the lineages in a region determined by the neuron at hand, since some lineages have more similar neighbors than others. This is effectively the same as setting the probability of distant lineages to 0 by hand and skipping their computation.

By pre-computing the pairwise NBLAST scores, all of these values can be computed quickly, leaving only $N \times 1$ new NBLAST comparisons per query neuron (where $N$ is the number of proximate neurites, whose dotprops can also be pre-computed). The classifier can then be retrained directly on that score matrix.
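With the score matrix pre-computed, the feature $\langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i}$ for a neuron is just its row of scores averaged within each lineage, so a new query needs only one new row of comparisons. A minimal sketch, using plain dictionaries as a hypothetical stand-in for the pandas pivot table the notebook builds; the names are illustrative only:

```python
from statistics import mean


def mean_scores_by_lineage(scores, query, skid2lin, skip_self=True):
    """Mean NBLAST score from `query` to the members of each lineage.

    scores   -- dict of dicts, scores[q][t] = NBLAST score of query q vs target t;
                a hypothetical stand-in for the pre-computed score matrix
    skid2lin -- dict mapping skeleton id -> lineage label
    """
    per_lineage = {}
    for target, score in scores[query].items():
        if skip_self and target == query:
            continue  # the self-score says nothing about lineage membership
        per_lineage.setdefault(skid2lin[target], []).append(score)
    return {lin: mean(vals) for lin, vals in per_lineage.items()}


scores = {1: {1: 1.0, 2: 0.7, 3: 0.5, 4: 0.1}}
skid2lin = {1: "A", 2: "A", 3: "A", 4: "B"}
print(mean_scores_by_lineage(scores, 1, skid2lin))  # {'A': 0.6, 'B': 0.1}
```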

* We can estimate $P( \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \mid \mathbf{x} \in L_i )$ for each $i$ (using all $\mathbf{x} \in L_i$).
* We can estimate $P( \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \mid \mathbf{x} \notin L_i )$ for each $i$ (using all $\mathbf{x} \notin L_i$).
* Assume the prior $P(\mathbf{x} \in L_i )$ is flat, so it cancels out when comparing ratios.
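The notebook delegates these estimates to `lineage_classifier.fit_gaussian_kdes`, whose internals are not shown here. As a hedged stand-in, both conditional densities can be estimated with a one-dimensional Gaussian kernel density estimate. The bandwidth rule (the normal-reference rule of thumb, $1.06\,\sigma\,n^{-1/5}$) and all names below are our assumptions:

```python
import math
from statistics import stdev


def gaussian_kde(samples, bandwidth=None):
    """Return a 1-D Gaussian kernel density estimate of `samples` as a function."""
    n = len(samples)
    if bandwidth is None:
        # Normal-reference rule of thumb; needs at least two distinct samples.
        bandwidth = 1.06 * stdev(samples) * n ** (-1 / 5)
    norm = 1.0 / (n * bandwidth * math.sqrt(2 * math.pi))

    def density(x):
        return norm * sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in samples)

    return density


def fit_conditionals(features, labels, lineage):
    """Estimate member and non-member densities of the mean score to `lineage`.

    features -- dict neuron -> {lineage: mean NBLAST score to that lineage}
    labels   -- dict neuron -> annotated lineage
    """
    members = [features[n][lineage] for n in features if labels[n] == lineage]
    others = [features[n][lineage] for n in features if labels[n] != lineage]
    return gaussian_kde(members), gaussian_kde(others)


features = {
    "a": {"A": 0.80, "B": 0.10},
    "b": {"A": 0.70, "B": 0.20},
    "c": {"A": 0.20, "B": 0.90},
    "d": {"A": 0.10, "B": 0.80},
}
labels = {"a": "A", "b": "A", "c": "B", "d": "B"}
p_member, p_other = fit_conditionals(features, labels, "A")
print(p_member(0.75) > p_other(0.75))  # True
```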

Now, for all proximate lineages $i$, compute the likelihood ratio, i.e. how much more probable the observed mean score is if the neuron belongs to the lineage than if it does not:
$$
\frac{P\left( \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \mid \mathbf{x} \in L_i \right)}{P\left( \langle\mathbf{x},\mathbf{y}\rangle_{\mathbf{y} \in L_i} \mid \mathbf{x} \notin L_i \right)}
$$
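A sketch of the ratio and of a fraction-of-best score, assuming (our interpretation, not confirmed by `lineage_classifier`) that `FractionOfBest` is each ratio divided by the largest one, which is why the top pick has exactly 1. For brevity it swaps the KDEs for plain Gaussian fits via `statistics.NormalDist`; all names are hypothetical:

```python
from statistics import NormalDist


def rank_lineages(query_features, member_samples, other_samples):
    """Rank candidate lineages for one neuron by likelihood ratio.

    query_features -- {lineage: mean NBLAST score of the query to that lineage}
    member_samples -- {lineage: mean scores of that lineage's members to it}
    other_samples  -- {lineage: mean scores of non-members to it}
    Returns (lineage, likelihood_ratio, fraction_of_best) tuples, best first.
    """
    ratios = {}
    for lin, score in query_features.items():
        # Gaussian fits stand in for the KDEs; each needs at least two samples.
        p_in = NormalDist.from_samples(member_samples[lin]).pdf(score)
        p_out = NormalDist.from_samples(other_samples[lin]).pdf(score)
        ratios[lin] = p_in / max(p_out, 1e-300)  # guard against underflow to 0
    best = max(ratios.values())
    ranked = sorted(ratios.items(), key=lambda item: item[1], reverse=True)
    return [(lin, ratio, ratio / best) for lin, ratio in ranked]


ranking = rank_lineages(
    {"A": 0.75, "B": 0.20},
    member_samples={"A": [0.70, 0.80, 0.75], "B": [0.85, 0.90, 0.80]},
    other_samples={"A": [0.10, 0.20, 0.15], "B": [0.10, 0.25, 0.20]},
)
print([(lin, round(frac, 3)) for lin, _, frac in ranking])
```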

For each lineage annotation, we go through each neuron in turn and train the classifier above on all other neurons (leave-one-out). If the annotated lineage is still the top pick, the annotation is consistent and we move on. If not, we double-check by plotting the lineages or, if needed, by inspecting bundling directly in the EM data.
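The leave-one-out loop can be sketched without the KDE machinery. For brevity, this hypothetical helper ranks lineages by the held-out neuron's mean score to each lineage's remaining members rather than by the likelihood ratio, so it illustrates the control flow, not the notebook's exact decision rule:

```python
from statistics import mean


def leave_one_out_flags(scores, labels):
    """Flag neurons whose annotated lineage is not the top pick when held out.

    scores -- dict of dicts of pairwise NBLAST scores, scores[q][t]
    labels -- dict neuron -> annotated lineage
    Simplification: the top pick is the lineage with the highest mean score to
    its remaining members, not the KDE likelihood ratio used in the notebook.
    """
    flagged = []
    for held_out, own_lin in labels.items():
        per_lineage = {}
        for other, lin in labels.items():
            if other != held_out:  # leave the query neuron out of every lineage
                per_lineage.setdefault(lin, []).append(scores[held_out][other])
        best = max(per_lineage, key=lambda lin: mean(per_lineage[lin]))
        if best != own_lin:
            flagged.append((held_out, own_lin, best))
    return flagged


# Toy data: "e" is annotated as lineage A but scores highly against lineage B.
labels = {"a": "A", "b": "A", "c": "B", "d": "B", "e": "A"}
pairs = {
    ("a", "b"): 0.9, ("a", "c"): 0.1, ("a", "d"): 0.1, ("a", "e"): 0.1,
    ("b", "c"): 0.1, ("b", "d"): 0.1, ("b", "e"): 0.2,
    ("c", "d"): 0.9, ("c", "e"): 0.8, ("d", "e"): 0.85,
}
scores = {n: {} for n in labels}
for (p, q), s in pairs.items():
    scores[p][q] = s  # NBLAST is not symmetric in general; symmetric here for brevity
    scores[q][p] = s
print(leave_one_out_flags(scores, labels))  # [('e', 'A', 'B')]
```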

In [None]:
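lin_ind = 15  # index into the sorted list of lineage groups
side = "l"  # "l" or "r"

nearby_neuron_radius = 2500  # nm; only used here to set the lineage search radius
nearby_lineage_radius = 4 * nearby_neuron_radius  # search radius for lineage entry points, nm
min_cable = 5000  # nm (not used in the cells below)
l_span = 30 * 1000  # path length of the initial segment used for dotprops, nm
resample_distance = 1000  # spacing of resampled dotprop points, nm
reroot_skeletons = True  # reroot nearby skeletons at their soma before analysis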

In [None]:
base_lin = sorted(hemilateral_groups.keys())[lin_ind]

# Compute distances between the base lineage's entry point and all lineage entry points.
if side not in ("l", "r"):
    raise ValueError("side must be 'l' or 'r', got {!r}".format(side))
xyz0 = np.atleast_2d(hemilateral_groups[base_lin]["{}_loc".format(side)])

ds = sp.spatial.distance.cdist(xyz0, np.array(list(lin_df.xyz.values)))[0]
lin_df_sp = lin_df.assign(ds=ds)

readable_side = {"l": "left side", "r": "right side"}
print("Working on {}, {}".format(base_lin, readable_side[side]))

---
### Find nearby lineages

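This uses `nearby_lineage_radius` to find lineage entry points near the base lineage's entry point. Optionally (if `reroot_skeletons` is set), the nearby skeletons are rerooted at their soma. Finally, neurons in the base lineage that carry more than one lineage annotation are flagged.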
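The distance filter in this step amounts to the following standard-library sketch. The dictionary of entry points and the lineage names are hypothetical stand-ins for `lin_df`, and the strict `<` matches the `ds < nearby_lineage_radius` test in the next cell:

```python
import math


def lineages_within(entry_points, xyz0, radius):
    """Lineages whose entry point lies strictly within `radius` of `xyz0`, nearest first.

    entry_points -- dict lineage name -> (x, y, z), in the same units as `radius`
    Lineages outside the radius are simply dropped, i.e. given prior probability 0.
    """
    dists = {lin: math.dist(xyz0, xyz) for lin, xyz in entry_points.items()}
    return sorted((lin for lin, d in dists.items() if d < radius), key=dists.get)


entry_points = {
    "lin_X_l": (0.0, 0.0, 0.0),
    "lin_Y_l": (3000.0, 4000.0, 0.0),
    "lin_Z_l": (20000.0, 0.0, 0.0),
}
print(lineages_within(entry_points, (0.0, 0.0, 0.0), radius=4 * 2500))
# ['lin_X_l', 'lin_Y_l']
```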

In [None]:
# Find all annotation_ids corresponding to lineages

all_lin_ids = l1data.parse_annotation_list(all_lins, output="ids")

# Find lineages whose entry points are within `nearby_lineage_radius` of the base point
rel_lins = [
    hemilateral_groups[row.lin][row.side[0]]
    for _, row in lin_df_sp[lin_df_sp.ds < nearby_lineage_radius].iterrows()
]
base_lin_ids = l1data.get_ids_from_annotations(
    hemilateral_groups[base_lin][side], flatten=True
)
near_lin_ids = l1data.get_ids_from_annotations(rel_lins, flatten=True)

# Make sure that the skeletons have proper roots
if reroot_skeletons:
    l1data.reroot_neurons_to_soma(near_lin_ids)

# Flag skeletons in the base lineage that carry more than one lineage annotation
multi_anno_ids = []
for skid in base_lin_ids:
    sk_annos = set(all_lin_ids).intersection(
        set(l1data.get_annotations_for_objects([skid]))
    )
    if len(sk_annos) > 1:
        multi_anno_ids.append(skid)
        print("{} has multiple lineage annotations!".format(skid))
        l1data.url_to_neurons(skid)

---
### Compute the base probabilities for the lineage groups in this region
*This step could be sped up by pre-computing all lineage-lineage NBLAST scores.*
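The next cell builds NBLAST dotprops from each skeleton's initial segment with `lineage_classifier.compute_initial_segment_dotprops(neurons, l_span, resample_distance)`, whose internals are not shown here. As a hedged illustration for a single unbranched path only (real skeletons branch, and we make no claim about how `catalysis` handles that), truncating at path length `l_span` and resampling every `resample_distance` could look like this; a dotprop pairs sample points with unit tangent vectors. The function name is hypothetical:

```python
import math


def initial_segment(path, l_span, step):
    """Resample a root-first 3-D polyline every `step` up to path length `l_span`.

    path -- list of (x, y, z) points ordered from the root outward (hypothetical input)
    Returns (points, tangents): resampled points and unit tangent vectors between
    consecutive samples, the two ingredients of an NBLAST dotprop.
    """
    points = [tuple(path[0])]
    travelled = 0.0  # arc length covered by the segments processed so far
    next_mark = step  # arc length at which the next sample is placed
    for a, b in zip(path, path[1:]):
        seg = math.dist(a, b)
        # Place every sample that falls on this segment and within l_span.
        while seg > 0 and next_mark <= min(travelled + seg, l_span):
            t = (next_mark - travelled) / seg
            points.append(tuple(a[k] + t * (b[k] - a[k]) for k in range(3)))
            next_mark += step
        travelled += seg
        if travelled >= l_span:
            break
    tangents = []
    for p, q in zip(points, points[1:]):
        d = math.dist(p, q)
        tangents.append(tuple((q[k] - p[k]) / d for k in range(3)))
    return points, tangents


pts, tans = initial_segment([(0, 0, 0), (3, 0, 0), (3, 4, 0)], l_span=100, step=2)
print(len(pts), pts[-1])  # 4 (3.0, 3.0, 0.0)
```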

In [None]:
near_lin_nrns = cat.NeuronList.from_id_list(near_lin_ids, l1data)
print("The nearby lineages are: {}".format(rel_lins))
# Compute dotprops for NBLAST. Targets use an initial segment twice as long as queries, so a query is never compared against an artificially truncated target.
initial_seg_dotprops_short = lineage_classifier.compute_initial_segment_dotprops(
    near_lin_nrns, l_span, resample_distance
)
initial_seg_dotprops_long = lineage_classifier.compute_initial_segment_dotprops(
    near_lin_nrns, 2 * l_span, resample_distance
)

# Convenience lookups.
# Mapping from lineage annotation id to the skids belonging to it. Skeletons with
# multiple lineage annotations are not removed here; they are listed in `multi_anno_ids`.
lin_dict = l1data.get_ids_from_annotations(rel_lins)

lin_name2id = {name: l1data.parse_annotation_list(name)[0] for name in rel_lins}
lin_id2name = {lin_name2id[name]: name for name in lin_name2id}

skid2lin = {}
skid2lin_name = {}
for lin in lin_dict:
    for skid in lin_dict[lin]:
        skid2lin[skid] = lin
        skid2lin_name[skid] = lin_id2name[lin]

In [None]:
# This could be pre-computed
lineage_lineage_nblast = pynblast.nblast_neurons(
    smat,
    nrns_q=initial_seg_dotprops_short,
    nrns_t=initial_seg_dotprops_long,
    as_dotprop=True,
    normalize=True,
).pivot(index="Queries", columns="Targets", values="S")

#### Compute likelihood ratios for each neuron in the base lineage, trained on all other neurons in the local environment (i.e. leave-one-out).

In [None]:
problem_ids = []
problem_table = []
for rel_skid in base_lin_ids:
    P_match, P_unmatch = lineage_classifier.compute_conditional_distributions(
        lineage_lineage_nblast, lin_dict, lin_name2id, rel_lins, skip_skids=rel_skid
    )
    kdes_match, kdes_match_norm = lineage_classifier.fit_gaussian_kdes(P_match)
    kdes_unmatch, kdes_unmatch_norm = lineage_classifier.fit_gaussian_kdes(P_unmatch)
    suggested_matches = lineage_classifier.lineage_likelihood_ratios(
        rel_skid,
        lineage_lineage_nblast,
        lin_dict,
        lin_name2id,
        rel_lins,
        kdes_match,
        kdes_match_norm,
        kdes_unmatch,
        kdes_unmatch_norm,
    )

    # Flag the neuron if its annotated lineage is not the top-ranked candidate.
    current_frac = suggested_matches[
        suggested_matches.Lineage == hemilateral_groups[base_lin][side]
    ].FractionOfBest.values
    if (current_frac < 1).any():
        problem_table.append(suggested_matches)
        problem_ids.append(rel_skid)

if problem_table:
    # pd.concat also handles the single-table case.
    problem_table = pd.concat(problem_table, ignore_index=True)
    print("{} skeleton ids need to be checked".format(len(problem_ids)))
    # Keep the top-ranked lineage and the currently annotated lineage for each flagged skeleton.
    focused_table = problem_table[
        (problem_table.FractionOfBest == 1)
        | (problem_table.Lineage == hemilateral_groups[base_lin][side])
    ].reset_index(drop=True)
else:
    print("No skeleton ids need to be checked")
    focused_table = pd.DataFrame(
        {"FractionOfBest": [], "LikelihoodRatio": [], "Lineage": [], "SkeletonId": []}
    )
focused_table

### Plot the candidate matches to confirm suggestions.

In [None]:
# Set row_num to a row of `focused_table` whose FractionOfBest is 1.00, i.e. a suggested lineage (this will be an even row number)
row_num = 3

rel_skid = focused_table.loc[row_num, "SkeletonId"]
suggest_lin = focused_table.loc[row_num, "Lineage"]
if rel_skid in multi_anno_ids:
    display(HTML("<mark><b>NEURON HAS MULTIPLE LINEAGE ANNOTATIONS!</b></mark>"))
l1data.url_to_neurons(rel_skid)
display(HTML("<b>Current lineage:</b> {}".format(hemilateral_groups[base_lin][side])))
display(
    HTML('<font color="blue"><b>Suggested lineage:</b> {}</font>'.format(suggest_lin))
)


data = []
data.append(
    catplt.path_data(
        initial_seg_dotprops_long[rel_skid][:, 0:3],
        color=(0.9, 0, 0.1),
        width=5,
        name=near_lin_nrns[rel_skid].name,
    )
)

# Plot the suggested lineage
if suggest_lin != hemilateral_groups[base_lin][side]:
    for skid in lin_dict[lin_name2id[suggest_lin]]:
        data.append(
            catplt.path_data(
                initial_seg_dotprops_long[skid][:, 0:3],
                color=(0.1, 0.1, 0.9),
                width=1,
                name=suggest_lin,
            )
        )

# Plot the currently assigned lineage
for skid in lin_dict[lin_name2id[hemilateral_groups[base_lin][side]]]:
    data.append(
        catplt.path_data(
            initial_seg_dotprops_long[skid][:, 0:3],
            color=(0.5, 0.5, 0.5),
            width=1,
            name=hemilateral_groups[base_lin][side],
        )
    )

layout = go.Layout({"showlegend": False, "width": 800, "height": 800})
fig = go.Figure(data=data, layout=layout)
py.iplot(fig)

---
### Plot all lineages in the local environment

In [None]:
color_cycle = cycle(plt.get_cmap("Dark2").colors)

# Assign one color per nearby lineage (keyed by lineage annotation id).
clr_dict = {lb: next(color_cycle) for lb in sorted(lin_dict.keys())}

data = []
# Plot the l_span-length initial segments (`initial_seg_dotprops` was never defined).
for skid in initial_seg_dotprops_short:
    lb = skid2lin[skid]
    data.append(
        catplt.path_data(
            initial_seg_dotprops_short[skid][:, 0:3],
            color=clr_dict[lb],
            width=4,
            name=lin_id2name[lb],
        )
    )

layout = go.Layout({"showlegend": False, "width": 500, "height": 500})
fig = go.Figure(data=data, layout=layout)
py.iplot(fig)