In [None]:
%matplotlib widget

In [None]:
import matplotlib.pyplot as plt
from seaborn import husl_palette
from em_reconstruction.loading import load_skeleton_from_kzip
from pathlib import Path

In [None]:
from bg_atlasapi.bg_atlas import BrainGlobeAtlas
from bg_space import AnatomicalSpace
from em_reconstruction.plotting import get_mpiref_coords

In [None]:
import numpy as np

In [None]:
from lotr import plotting as pltltr
from lotr import DATASET_LOCATION
from bg_atlasapi.core import Atlas
# Colour definitions exported by lotr.plotting (not used in the cells shown here).
COLS = pltltr.COLS

# Local IPN anatomy atlas, stored in the "anatomy" folder next to the lotr dataset.
ANATOMY_DIR = DATASET_LOCATION.parent / "anatomy"
atlas = Atlas(ANATOMY_DIR / "ipn_zfish_0.5_um_v1.6")

In [None]:
# Input: skeletons exported together in a single .k.zip on the lab share.
path = (
    r"\\Funes\Shared\experiments\E0076_EM_reconstructions"
    r"\skeletons_converted\all_ahb.k.zip"
)
# Output folder for exported figures (savefig does not create it).
save_dir = Path(r"C:\Users\lavian\Desktop\EM plots")

In [None]:
# Load every skeleton from the archive and give each one its own HUSL colour.
cell_list = load_skeleton_from_kzip(path)
n_cells = len(cell_list)
colors = husl_palette(n_cells)

In [None]:
# load MPI reference (BrainGlobe "mpin_zfish_1um" atlas)
mpi = BrainGlobeAtlas("mpin_zfish_1um")
ref = mpi.reference

# Annotation label treated as the IPN by this notebook.
IPN_LABEL = 869

# mask for the IPN.
# Copy first: `mpi.annotation` hands back the atlas' own (lazily loaded and
# kept) array, so masking it in place would silently wipe every other region
# from `mpi.annotation` for any later use of the atlas.
ipn = mpi.annotation.copy()
ipn[ipn != IPN_LABEL] = 0

# the orientation for plotting: remap both stacks from the atlas orientation
# to the "pil" space used for display
space_imshow = AnatomicalSpace("pil", shape=ref.shape)
space_ref = AnatomicalSpace(mpi.orientation, shape=ref.shape)
ref = space_ref.map_stack_to(space_imshow, ref)
ipn = space_ref.map_stack_to(space_imshow, ipn)

# parameters for plotting coronal, horizontal, and sagittal planes
# NOTE(review): locs / x / y / ref_planes / ipn_planes are not used by the
# cells shown here.
locs = [(0,0), (1,0), (1,1)]
x = [2, 2, 1]
y = [1, 0, 0]
# Mean-intensity projection of the reference along each axis.
ref_planes = [ref.mean(i) for i in range(3)]
# Max projection of the IPN mask as 1 / NaN, so pixels outside the IPN are
# NaN (drawn transparent by imshow's default "bad" colour).
ipn_planes = [np.where(ipn.max(i) > 0, 1, np.nan) for i in range(3)]

In [None]:
# Axis limits for each anatomical dimension; every projection view below
# combines two of them.
bs = {
    "frontal": (0, 200),
    "vertical": (-4, 170),
    "sagittal": (-50, 200),
}
# Wider alternative limits:
# bs = dict(frontal=(-1000, 2000), vertical=(-1000, 1000), sagittal=(-1000, 2000))

# Per-view bounds, each a pair of the per-dimension limits above.
view_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

plotter = pltltr.AtlasPlotter(
    atlas=atlas,
    structures=["ipn", "dors_ipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=view_bounds,
)

In [None]:
# Draw the atlas projection views configured on `plotter`.
# NOTE(review): `f` is rebound to a new figure by the next cell.
fig_and_axes = plotter.generate_projection_plots()
f, axs = fig_and_axes

In [None]:
# One row of three panels for the skeleton projections.
f, ax = plt.subplots(nrows=1, ncols=3, figsize=(14, 5))

In [None]:
# Columns of `cell.coords_mpi` drawn on the horizontal (X) and vertical (Y)
# axis of each panel of `ax`; one entry per panel that gets drawn.
# (Three-panel variant used previously: X = [2, 2, 0], Y = [1, 0, 1].)
X = [2, 0]
Y = [1, 1]

# NOTE(review): offsets kept from an earlier layout; not used below.
x_diff = [180, 180, -270]
y_diff = [-270, -550, -560]

for a in ax:
    # invert_yaxis() toggles the direction, so guard it to keep this cell
    # re-runnable on the same axes.
    if not a.yaxis_inverted():
        a.invert_yaxis()
    a.set(aspect="equal")
    a.axis("off")

for cell, color in zip(cell_list, colors):
    coords = cell.coords_mpi
    # Position of each node id within cell.nodes.index, used as the row
    # index into `coords`.
    idx_dict = {idx: i for i, idx in enumerate(cell.nodes.index)}
    for src, dst in cell.edges.items():
        idx = [idx_dict[src], idx_dict[dst]]
        # zip() stops at the shortest input, so only len(X) panels are drawn.
        # dim_x / dim_y (not x / y) so the MPI cell's `x` / `y` are not clobbered.
        for a, dim_x, dim_y in zip(ax, X, Y):
            a.plot(coords[idx, dim_x], coords[idx, dim_y], lw=0.5, c=color)

In [None]:
# Vector export of the skeleton figure (dpi only affects rasterized content).
pdf_path = save_dir / "em_220930_partial2_big.pdf"
f.savefig(pdf_path, dpi=300)

In [None]:
# Raster export of the same skeleton figure.
jpg_path = save_dir / "em_220930_01.jpg"
f.savefig(jpg_path, dpi=300)

In [None]:
# Second skeleton figure: three panels, but X / Y have a single entry, so only
# the first panel is drawn (zip() stops at the shortest input).
f1, ax1 = plt.subplots(1, 3, figsize=(14, 5))
X = [2]
Y = [0]

for a in ax1:
    # Axes are freshly created in this cell, so a single invert is correct.
    a.invert_yaxis()
    a.set(aspect="equal")
    a.axis("off")
for cell, color in zip(cell_list, colors):
    coords = cell.coords_mpi
    # Position of each node id within cell.nodes.index, used as the row
    # index into `coords`.
    idx_dict = {idx: i for i, idx in enumerate(cell.nodes.index)}
    for src, dst in cell.edges.items():
        idx = [idx_dict[src], idx_dict[dst]]
        # dim_x / dim_y (not x / y) so the MPI cell's `x` / `y` are not clobbered.
        for a, dim_x, dim_y in zip(ax1, X, Y):
            a.plot(coords[idx, dim_x], coords[idx, dim_y], lw=0.5, c=color)


In [None]:
# Adjust the second skeleton figure explicitly. plt.subplots_adjust would act
# on whatever pyplot considers "current", which depends on cell execution
# order and on which widget figure was last clicked.
f1.subplots_adjust(left=0.01)

In [None]:
# Raster export of the second skeleton figure.
out_jpg_2 = save_dir / "em_220930_2.jpg"
f1.savefig(out_jpg_2, dpi=300)