In [None]:
%load_ext autoreload
%matplotlib widget

In [None]:
%autoreload
from pathlib import Path
import numpy as np
%autoreload
from bg_atlasapi.bg_atlas import BrainGlobeAtlas
from bg_space import AnatomicalSpace
from em_reconstruction.loading import load_kzip, load_nodes
from em_reconstruction.plotting import get_mpiref_coords
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# where k.zip files are stored
# NOTE(review): machine-specific absolute paths — adjust when running elsewhere.
data_dir = Path(r"\\funes\Shared\Hagar")

# where plots will be saved (png and eps files)
save_dir = Path(r"C:\Users\lavian\Desktop\EM plots")
# parents=True so a first run on a fresh machine does not fail on a missing
# intermediate folder; exist_ok=True keeps the cell safe to re-run.
save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load every reconstruction and convert it to MPI-reference coordinates.
cells = {} # key: file name or cell name, value: coordinates of the cell in mpi reference

# Iterate in sorted order so the insertion order of `cells` (and therefore the
# colour each cell receives in the plots below) does not depend on the order
# in which the filesystem happens to list the directory.
for f in sorted(data_dir.glob("*.k.zip")):
    datatype, output = load_kzip(f)

    # if the k.zip contains mesh files, there should be only one cell;
    # it is keyed by the file name up to its first dot
    if datatype == "mesh":
        cells[f.stem.split(".")[0]] = get_mpiref_coords(datatype, output)

    # if there is no mesh, the annotation file may contain multiple cells
    elif datatype == "annotation":
        for cell in output:
            cell_id, nodes = load_nodes(cell)
            cells[cell_id] = get_mpiref_coords(datatype, nodes)

    # make skipped files visible instead of dropping them silently
    else:
        print(f"Skipping {f.name}: unhandled datatype {datatype!r}")

In [None]:
# load MPI reference
mpi = BrainGlobeAtlas("mpin_zfish_1um")
ref = mpi.reference

# Annotation label of the IPN in the MPI atlas.
IPN_LABEL = 869

# mask for the IPN: 1 inside the IPN, 0 elsewhere. A new array is built
# instead of assigning into `mpi.annotation`, so the atlas's own annotation
# data is never overwritten (NOTE(review): assumes the atlas hands back the
# same cached array on every access, which in-place edits would corrupt).
# Only `> 0` is tested downstream, so a compact uint8 mask is sufficient.
annotation = mpi.annotation
ipn = (annotation == IPN_LABEL).astype(np.uint8)

# the orientation for plotting
space_imshow = AnatomicalSpace("pil", shape=ref.shape)
# `space_ref` describes the atlas's native axis order; it is built before
# `ref` is remapped below, so it uses the original stack shape.
space_ref = AnatomicalSpace(mpi.orientation, shape=ref.shape)
ref = space_ref.map_stack_to(space_imshow, ref)
ipn = space_ref.map_stack_to(space_imshow, ipn)

# parameters for plotting coronal, horizontal, and sagittal planes:
# subplot grid positions, and for each plane the remapped axis drawn
# horizontally (x) and vertically (y)
locs = [(0,0), (1,0), (1,1)]
x = [2, 2, 1]
y = [1, 0, 0]
# mean-intensity projection of the reference along each axis
ref_planes = [ref.mean(i) for i in range(3)]
# max-projection of the IPN mask: 1 inside, NaN (drawn as "bad"/transparent by imshow) outside
ipn_planes = [np.where(ipn.max(i) > 0, 1, np.nan) for i in range(3)]

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(8, 4), constrained_layout=True)

# One colour per cell, assigned in `cells` insertion order. The 50-colour base
# sequence is repeated twice (100 entries in total).
col_list = [
    "#107C10", "#002050", "#A80000", "#5C2D91", "#004B50", "#0078D7", "#D83B01", "#B4009E", "#01B8AA", "#F2C80F", "#8AD4EB", "#FE9666", "#A66999", "#73B761", "#CD4C46", "#71AFE2", "#8D6FD1", "#EE9E64", "#95DABB", "#4A8DDC", "#4C5D8A", "#F3C911", "#DC5B57",
    "#33AE81", "#95C8F0", "#DD915F", "#9A64A0", "#B6B0FF", "#3049AD", "#FF994E", "#C83D95", "#FFBBED", "#42F9F9", "#00B2D9", "#FFD86C", "#009292", "#FE6DB6", "#FEB5DA", "#480091",
    "#B66DFF", "#B5DAFE", "#6DB6FF", "#4A8DDC", "#4C5D8A", "#F3C911", "#DC5B57", "#33AE81", "#95C8F0", "#DD915F", "#9A64A0",
] * 2

# Background of each panel: max-projection of the IPN mask.
for i in range(3):
    ax[i].imshow(ipn_planes[i], cmap="gray", origin="lower")
    ax[i].axis("off")

# Overlay every cell on all three projections. Colours cycle once the palette
# is exhausted instead of raising IndexError.
for count, (cell, coords) in enumerate(cells.items()):
    color = col_list[count % len(col_list)]
    cc = space_ref.map_points_to(space_imshow, coords)

    for i, (xi, yi) in enumerate(zip(x, y)):
        ax[i].scatter(cc[:, xi], cc[:, yi], s=0.01, c=color)

# To export: fig.savefig(save_dir / "all_cells.eps", dpi=300); plt.close(fig)

In [None]:
%autoreload

from lotr import plotting as pltltr
from lotr import DATASET_LOCATION
from bg_atlasapi.core import Atlas
COLS = pltltr.COLS  # lotr plotting constants (presumably its colour palette)

# IPN anatomy atlas, stored in the "anatomy" folder next to the lotr dataset.
atlas = Atlas(
    DATASET_LOCATION.parent / "anatomy" / "ipn_zfish_0.5_um_v1.6"
)

In [None]:
# Per-cell colour palette, (re)defined here so the atlas-overlay section does
# not depend on the earlier overview cell. The 50-colour base sequence is
# repeated twice, giving 100 entries.
col_list = [
    "#107C10", "#002050", "#A80000", "#5C2D91", "#004B50", "#0078D7", "#D83B01", "#B4009E", "#01B8AA", "#F2C80F", "#8AD4EB", "#FE9666", "#A66999", "#73B761", "#CD4C46", "#71AFE2", "#8D6FD1", "#EE9E64", "#95DABB", "#4A8DDC", "#4C5D8A", "#F3C911", "#DC5B57",
    "#33AE81", "#95C8F0", "#DD915F", "#9A64A0", "#B6B0FF", "#3049AD", "#FF994E", "#C83D95", "#FFBBED", "#42F9F9", "#00B2D9", "#FFD86C", "#009292", "#FE6DB6", "#FEB5DA", "#480091",
    "#B66DFF", "#B5DAFE", "#6DB6FF", "#4A8DDC", "#4C5D8A", "#F3C911", "#DC5B57", "#33AE81", "#95C8F0", "#DD915F", "#9A64A0",
] * 2


In [None]:
# Axis limits for each anatomical dimension. Each projection's bounds are
# assembled from the limits of two of these dimensions.
# Wider alternative: frontal=(-1000, 2000), vertical=(-1000, 1000), sagittal=(-1000, 2000)
bs = {
    "frontal": (0, 200),
    "vertical": (-4, 170),
    "sagittal": (-50, 200),
}

plotter = pltltr.AtlasPlotter(
    atlas=atlas,
    structures=["ipn", "dors_ipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict={
        "frontal": [bs["vertical"], bs["frontal"]],
        "horizontal": [bs["sagittal"], bs["frontal"]],
        "sagittal": [bs["vertical"], bs["sagittal"]],
    },
)

In [None]:
# Draw the lotr atlas projections; the returned axes are used below to
# overlay the EM reconstructions.
(f, axs) = plotter.generate_projection_plots()

In [None]:
# Overlay every reconstructed cell on the lotr atlas projections.
# x_diff / y_diff are hand-tuned per-panel offsets that align the remapped
# MPI-reference coordinates with the atlas plotter's axes; the coordinate
# plotted from `y` is negated, and the last panel swaps and negates both.
x_diff = [180, 180, -270]
y_diff = [-270, -550, -560]

for count, (cell, coords) in enumerate(cells.items()):
    # cycle through the palette instead of raising IndexError past its end
    color = col_list[count % len(col_list)]
    cc = space_ref.map_points_to(space_imshow, coords)

    for i, (xi, yi) in enumerate(zip(x, y)):
        if i < 2:
            px = cc[:, xi] - x_diff[i]
            py = -cc[:, yi] - y_diff[i]
        else:
            px = -cc[:, yi] - y_diff[i]
            py = -cc[:, xi] - x_diff[i]
        # whole reconstruction as tiny dots, plus its first point enlarged
        axs[i].scatter(px, py, s=0.01, c=color)
        axs[i].scatter(px[0], py[0], s=10, c=color)
    


In [None]:
# Save the atlas-projection figure holding all overlaid cells. The figure is
# taken from the projection axes rather than pyplot's "current figure", and
# the filename names the content instead of reusing a leftover loop variable
# (the figure contains every cell, not a single one).
axs[0].figure.savefig(save_dir / "all_cells_ipn_atlas.eps", dpi=300)

In [None]:
# Three orthogonal mean projections of the reference on a 2x2 grid, with all
# cells overlaid in red. Row/column ratios follow the remapped stack shape so
# each panel keeps its aspect ratio.
fig = plt.figure(figsize=(8, 11), constrained_layout=True)
gs = fig.add_gridspec(2, 2, width_ratios=[ref.shape[2], ref.shape[1]], height_ratios=[ref.shape[1], ref.shape[0]])

# Remap each cell once rather than once per plane.
cells_imshow = [space_ref.map_points_to(space_imshow, coords) for coords in cells.values()]

for i, (loc, xi, yi) in enumerate(zip(locs, x, y)):
    # Each plane gets its own subplot in the grid; previously the stale `ax`
    # (an array of Axes from an earlier figure) was used and no subplot was added.
    ax = fig.add_subplot(gs[loc])
    ax.imshow(ref_planes[i], cmap="gray_r", origin="lower")
    for cc in cells_imshow:
        ax.scatter(cc[:, xi], cc[:, yi], s=0.01, c="r")
    ax.axis("off")

fig.savefig(save_dir / "all_cells.png")
