In [12]:
import os
from typing import List, Dict, Tuple, Any
import time
import json

import numpy as np
from dotenv import load_dotenv
import nrrd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from src.core.utils import fs
from src.flatmap.utils.bridge import get_boundaries, get_isocortex_3d_projector, get_average_template_projection
from src.flatmap.utils.read_ni_output import get_marker_points_from_xml
from src.flatmap.ccf_streamlines.morphology import transform_coordinates_to_volume
from src.flatmap.utils.nb_utils import get_projection_images_from_dict, setup_main_plot, plot_boundaries
from src.flatmap.utils.nb_utils import expand_label_image, dilate_image, dilate_binary_image

from src.flatmap.ccf_streamlines.projection import IsocortexCoordinateProjector
from src.flatmap import external_data_registry as data_registry

In [2]:
if load_dotenv() is False:
    print("Failed to load .env file")

EXPERIMENT_DATA_DIR = os.getenv("CASE_FILE_DATA_DIR")
if EXPERIMENT_DATA_DIR is None:
    print("Failed to load CASE_FILE_DATA_DIR from .env file")

example_case_file_dir = os.path.join(EXPERIMENT_DATA_DIR, "working")
case_files = fs.find_in_dir_with_ext(example_case_file_dir, ".xml")
if len(case_files) == 0:
    print("No case files found in directory")

CASE_FILE_PATH = case_files[0]
ALLEN_DATA_DIR = os.getenv("ALLEN_ATLAS_FILES")

if ALLEN_DATA_DIR is None:
    print("Failed to load ALLEN_ATLAS_FILES from .env file")

ALLEN_ANNOTATION_PATH = os.path.join(ALLEN_DATA_DIR, "annotation_10.nrrd")

if os.path.exists(ALLEN_ANNOTATION_PATH) is False: \
        print(f"Failed to find allen annotation file: {ALLEN_ANNOTATION_PATH}")

In [4]:
print(f"Case file path: {CASE_FILE_PATH}")
print(f"Allen annotation path: {ALLEN_ANNOTATION_PATH}")

Case file path: D:/data/neuro/flatmap/experiments\working\2000_ChAT_rabies 6cases Cortex.xml
Allen annotation path: D:/data/neuro/Allen\annotation_10.nrrd


In [13]:
view_space_for_other_hemisphere = "flatmap_butterfly"
view_lookup_file = "flatmap_butterfly"

surface_paths_file = data_registry.get_streamline_path("surface_paths_10_v3")
closest_surface_voxel_reference_file = data_registry.get_streamline_path("closest_surface_voxel_lookup")
streamline_layer_thickness_file = data_registry.get_isocortex_metric_path("cortical_layers_10_v2")
view_lookup_file = data_registry.get_view_lookup_file_path(view_lookup_file)
layer_depth_file = data_registry.get_isocortex_metric_path("avg_layer_depths")

with open(layer_depth_file, "r") as f:
    layer_tops = json.load(f)

layer_thicknesses = {
    "Isocortex layer 1": layer_tops["2/3"],
    "Isocortex layer 2/3": layer_tops["4"] - layer_tops["2/3"],
    "Isocortex layer 4": layer_tops["5"] - layer_tops["4"],
    "Isocortex layer 5": layer_tops["6a"] - layer_tops["5"],
    "Isocortex layer 6a": layer_tops["6b"] - layer_tops["6a"],
    "Isocortex layer 6b": layer_tops["wm"] - layer_tops["6b"],
}

proj_butterfly_slab = IsocortexCoordinateProjector(
    surface_paths_file=surface_paths_file,
    closest_surface_voxel_reference_file=closest_surface_voxel_reference_file,
    layer_thicknesses=None,
    streamline_layer_thickness_file=streamline_layer_thickness_file,
    resolution=(10, 10, 10),
    projection_file=view_lookup_file)

In [6]:
bf_left_boundaries, bf_right_boundaries = get_boundaries(view_lookup_file=view_lookup_file,
                                                         view_space_for_other_hemisphere=view_space_for_other_hemisphere)
bf_projection_max = get_average_template_projection()

loading path information


100%|██████████| 1016/1016 [00:12<00:00, 80.95it/s] 


In [7]:
marker_points = get_marker_points_from_xml(CASE_FILE_PATH)

In [19]:
morphological_list = []
names = marker_points.keys()
factor_values = (-10, -10, 10)
scale: str = "voxels"
thickness_type: str = "unnormalized"
# thickness_type: str = "normalized_full"
# thickness_type: str = "normalized_layers"
hemisphere: str = "both"
view_space_for_other_hemisphere: bool = False,
drop_voxels_outside_view_streamlines: bool = False

for i, points in enumerate(marker_points.values()):
    tic = time.perf_counter()

    points_vals_copy = points.copy()

    points_vals_copy = np.floor(points_vals_copy / factor_values).astype(int)
    points_vals = points_vals_copy.copy()
    points_vals[:, 0] = points_vals_copy[:, 2]
    points_vals[:, 1] = points_vals_copy[:, 1]
    points_vals[:, 2] = points_vals_copy[:, 0]

    morph_layers= proj_butterfly_slab.project_coordinates(
        coords=points_vals,
        scale=scale,
        thickness_type=thickness_type,
        hemisphere = hemisphere,
        view_space_for_other_hemisphere=view_space_for_other_hemisphere,
        drop_voxels_outside_view_streamlines=drop_voxels_outside_view_streamlines)

    morphological_list.append(morph_layers)
    toc = time.perf_counter()
    print(f"Finished in {toc - tic:0.4f} seconds")

loading path information


100%|██████████| 1/1 [00:00<00:00, 1000.31it/s]


loading path information


100%|██████████| 1/1 [00:00<00:00, 1000.31it/s]

Finished in 2.1654 seconds





loading path information


100%|██████████| 1/1 [00:00<00:00, 499.92it/s]


loading path information


100%|██████████| 1/1 [00:00<00:00, 999.83it/s]

Finished in 2.0618 seconds





loading path information


100%|██████████| 1/1 [00:00<00:00, 1166.70it/s]


loading path information


100%|██████████| 1/1 [00:00<00:00, 1011.65it/s]

Finished in 2.0961 seconds





loading path information


100%|██████████| 1/1 [00:00<00:00, 500.22it/s]


loading path information


100%|██████████| 1/1 [00:00<00:00, 1000.07it/s]

Finished in 2.0751 seconds





loading path information


100%|██████████| 1/1 [00:00<00:00, 500.33it/s]


loading path information


100%|██████████| 1/1 [00:00<00:00, 1000.55it/s]

Finished in 2.0752 seconds





loading path information


100%|██████████| 1/1 [00:00<00:00, 999.12it/s]


loading path information


100%|██████████| 1/1 [00:00<00:00, 999.36it/s]

Finished in 2.0528 seconds





In [17]:
def create_label_images(image_list: List[np.ndarray]):
    label_image = None
    for i, image in enumerate(image_list):
        main_max = image.max(axis=2).T
        if label_image is None:
            label_image = np.zeros_like(main_max)
        # main_max = dilate_image_binary(main_max, radius= 3)
        label_image[main_max > 0] += i + 1
        print(np.count_nonzero(main_max))

    return label_image


def create_top_left_images(image_list: List[np.ndarray]):
    top_image = None
    left_image = None
    for i, image in enumerate(image_list):
        top_max = image.max(axis=1).T
        left_max = image.max(axis=0)
        if top_image is None:
            top_image = np.zeros_like(top_max)
            left_image = np.zeros_like(left_max)
        top_image[top_max > 0] = i + 1
        left_image[left_max > 0] = i + 1

    return {"left": left_image,"top": top_image}

In [18]:
%%time
# morphological_list = morphological_dict["unnormalized"] # unnormalized, normalized_layers, normalized_full
label_image = create_label_images(morphological_list)
top_left_images = create_top_left_images(morphological_list)
top_image = np.fliplr(top_left_images["top"])
left_image = top_left_images["left"]

plt.style.use('dark_background')
label_cmap = ListedColormap(['black', 'green', 'red', 'blue', 'orange', 'yellow', 'purple', 'pink', 'cyan', 'brown', 'white'])

fig, axes = setup_main_plot()
axes = plot_boundaries(axes, bf_left_boundaries, bf_right_boundaries)
color_map_template = "Greys_r"
color_map = "Dark2_r"

dilated_image = label_image.copy()
mid_point = dilated_image.shape[1] // 2
end_point = dilated_image.shape[1]
dilated_image[:, 0:mid_point] = expand_label_image(dilated_image[:, 0:mid_point], 6)
dilated_image[:, mid_point:end_point] = expand_label_image(dilated_image[:, mid_point:end_point], 3)
dilated_image = np.fliplr(dilated_image)

axes[1, 1].imshow(bf_projection_max.T, cmap=color_map_template, alpha=1.0, interpolation=None)
ax = axes[1, 1].imshow(dilated_image, cmap=label_cmap, alpha=0.5, interpolation=None)

axes[0, 1].imshow(top_image, cmap=label_cmap, alpha=1.0, interpolation=None)
axes[1, 0].imshow(left_image, cmap=label_cmap, alpha=1.0, interpolation=None)

AxisError: axis 2 is out of bounds for array of dimension 2