In [1]:
# Imports 
import sys
from pathlib import Path

# Add project root to path
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

In [2]:
import tifffile
import hdbscan
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import networkx as nx

from scipy.spatial import Delaunay
from src.models.model_loader import ModelLoader
from src.utils.helpers import compare_two_images
from skimage.measure import label
from scipy.spatial import Delaunay
from collections import defaultdict
from stardist.models import StarDist2D
from stardist.plot import render_label
from src.utils.helpers import cut_out_image
from skimage.exposure import rescale_intensity
from sklearn.decomposition import PCA
from scipy.spatial.distance import cosine
from math import acos, degrees
from sklearn.cluster import DBSCAN

bioimageio_utils.py (2): pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.


In [3]:
# Load four images, which present a good example of layers
LAYER_PATH = project_root / "data/layer_examples"

paths = [(LAYER_PATH / image_path) for image_path in os.listdir(LAYER_PATH)]
images = list(map(tifffile.imread, paths))

In [72]:
# Normalize images
from src.utils.reinhard_normalizer import ReinhardNormalizer

normalizer = ReinhardNormalizer()
images = [normalizer.normalize(img) for img in images]

# Create Tissue Masks

In [5]:
# Specify model to load
model_loader = ModelLoader()
MODEL_CFG = "unet_2"
model = model_loader.load_cnn_model(MODEL_CFG, "unet_2c")

Loaded CNN: unet_2c


In [6]:
from src.data.preprocessing import inference_processing
from skimage.transform import resize
ORG_RES = (1920, 2560)

In [7]:
masks = []
labeled_tissues = []
device = "mps" if torch.mps.is_available() else "cpu"
print(device)
for img in images:
    img = inference_processing(img, device)

    with torch.no_grad():
        pred_logits = model(img)
        pred_mask = torch.argmax(pred_logits, dim=1).squeeze()
        pred_mask = pred_mask.cpu().numpy()
        pred_mask = resize(pred_mask, ORG_RES, anti_aliasing=True)
    labeled_tissue = label(pred_mask > 0, connectivity=2)
    labeled_tissues.append(labeled_tissue)
    masks.append(pred_mask)

mps


In [84]:
# for img, mask in zip(images, masks):
#     compare_two_images(img, mask, "Normalized Image", "Predicted Mask")

# Segment Nuclei

In [8]:
stardist_model = StarDist2D.from_pretrained("2D_versatile_he")

Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.


In [9]:
nuclei_masks = []
nuclei_data_dicts = []

for img, mask in zip(images, masks):
    image_normed = rescale_intensity(img, out_range=(0, 1))
    labels, data_dict = stardist_model.predict_instances(image_normed, axes='YXC', prob_thresh=0.25, nms_thresh=0.01, return_labels=True)
    filtered_labels = cut_out_image(labels, mask)
    binary_labels = (filtered_labels > 0).astype(np.uint8)
    nuclei_masks.append(binary_labels)
    nuclei_data_dicts.append(data_dict)

In [87]:
# for img, mask in zip(images, nuclei_masks):
#     compare_two_images(img, mask, "Normalized Image", "Filtered nuclei mask")

# Clean Segmentation

In [10]:
def poly_area(x,y):
    return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)))

In [11]:
def calculate_median_area(coordinates: np.ndarray) -> float:
    areas = []
    for coord in coordinates:
        area = poly_area(np.array(coord[0]), np.array(coord[1]))
        areas.append(area)
    
    median_area = np.median(np.array(areas))
    return median_area

In [12]:
def filter_data_dict_mask(mask: np.ndarray, data_dict: dict[str, any]) -> list:
    points = data_dict["points"]
    filtered_points = []
    binary_mask = (mask > 0).astype(int)
    for point, coord in zip(points, data_dict["coord"]):
        x, y = int(point[0]), int(point[1])
        if binary_mask[x, y] == 1:
            filtered_points.append([point[0], point[1]])

    filtered_data_dict = dict(data_dict)
    filtered_data_dict["points"] = np.array(filtered_points)
    
    return filtered_data_dict

In [13]:
def filter_data_dict(mask: np.ndarray, data_dict: dict[str, any], area_th: float = 0.5) -> list:
    points = data_dict["points"]
    median_area = calculate_median_area(data_dict["coord"])
    filtered_points = []
    filtered_coords = []
    filtered_probs = []

    binary_mask = (mask > 0).astype(int)
    for i, (point, coord) in enumerate(zip(points, data_dict["coord"])):
        x, y = int(point[0]), int(point[1])
        area = poly_area(np.array(coord[0]), np.array(coord[1]))
        if binary_mask[x, y] == 1 and area > area_th * median_area:
            filtered_points.append([point[0], point[1]])
            filtered_coords.append(coord)
            filtered_probs.append(data_dict["prob"][i])

    filtered_data_dict = dict(data_dict)
    filtered_data_dict["points"] = np.array(filtered_points)
    filtered_data_dict["coord"] = np.array(filtered_coords)
    filtered_data_dict["prob"] = np.array(filtered_probs)
    
    return filtered_data_dict

In [14]:
def plot_image_and_points(image: np.ndarray, data_dict: dict[str,np.ndarray]) -> None:
    points = data_dict["points"]
    height, width = image.shape[:2]
    
    plt.figure(figsize=(12, 12))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.axis("off")
    
    plt.subplot(1, 2, 2)
    n_points = len(points)
    random_values = np.random.rand(n_points)
    plt.scatter(points[:, 1], points[:, 0], s=5, c=random_values, cmap='tab20')
    plt.xlim(0, width)
    plt.ylim(height, 0)
    plt.axis("off")
    plt.gca().set_aspect('equal')
    plt.show()

In [15]:
test_data_dict = nuclei_data_dicts[1]
len(test_data_dict["points"])

595

In [16]:
filter_test_data_dict = filter_data_dict(masks[1], test_data_dict)
len(filter_test_data_dict["points"])

330

In [17]:
mask_data_dict = filter_data_dict_mask(masks[1], test_data_dict)
len(mask_data_dict["points"])

359

In [18]:
# for image, mask, data_dict in zip(images, masks, nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(mask, data_dict)
#     plot_image_and_points(image, filtered_data_dict)

In [19]:
def plot_image_with_points(image: np.ndarray, data_dict: dict[str,np.ndarray], ax=None) -> None:
    if ax is None:
        plt.figure(figsize=(12, 12))
        show_plot = True
    else:
        show_plot = False
    
    points = data_dict["points"]
    plt.imshow(image)
    n_points = len(points)
    random_values = np.random.rand(n_points)
    plt.scatter(points[:, 1], points[:, 0], 
             c=random_values, cmap='tab20', s=15, alpha=1)
    plt.axis('off')
    plt.tight_layout()

    if show_plot:
        plt.show()

In [20]:
def compare_images_with_points(image_1: np.ndarray, data_dict_1: dict[str,np.ndarray], image_2: np.ndarray, data_dict_2: dict[str, np.ndarray]) -> None:
    plt.figure(figsize=(18, 12))

    plt.subplot(1, 2, 1)
    points_1 = data_dict_1["points"]
    plt.imshow(image_1)
    plt.plot(points_1[:, 1], points_1[:, 0], 'o', 
             color='blue', markersize=3, alpha=0.8)
    plt.axis('off')

    plt.subplot(1, 2, 2)
    points_2 = data_dict_2["points"]
    plt.imshow(image_2)
    plt.plot(points_2[:, 1], points_2[:, 0], 'o', 
             color='blue', markersize=3, alpha=0.8)
    plt.axis('off')


    plt.tight_layout()
    plt.show()

In [99]:
# for image, mask, data_dict in zip(images, masks, nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(mask, data_dict)
#     plot_image_with_points(image, filtered_data_dict)

In [100]:
# # Compare
# for image, mask, data_dict in zip(images, masks, nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(mask, data_dict)
#     compare_images_with_points(image, filtered_data_dict, image, data_dict)

In [21]:
def recompute_nuclei_centers(data_dict: dict[str,any]) -> dict[str,any]:
    points = data_dict["points"]
    corrected_points = []
    coords = data_dict["coord"]
    for i, coord in enumerate(coords):
        mean_x = np.round(np.mean(coord[0]))
        mean_y = np.round(np.mean(coord[1]))
        corrected_points.append([mean_x, mean_y])
    
    recentered_dict = dict(data_dict)
    recentered_dict["points"] = np.array(corrected_points)

    return recentered_dict

In [102]:
# # Recompute nuclei and compare
# for image, mask, data_dict in zip(images, masks, nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(mask, data_dict)
#     centered_data_dict = recompute_nuclei_centers(filtered_data_dict)
#     compare_images_with_points(image, filtered_data_dict, image, centered_data_dict)

# Graph Layer Detection

In [22]:
# Find neighbors

def get_delaunay_neighbors(points):
    points = np.array(points)
    tri = Delaunay(points)
    
    neighbors = defaultdict(set)
    
    for simplex in tri.simplices:
        for i in range(3):
            for j in range(3):
                if i != j:
                    neighbors[simplex[i]].add(simplex[j])
    
    return dict(neighbors), points

In [23]:
def get_nucleus_orientation(boundary_points):
    boundary_points = np.array(boundary_points)
    
    if boundary_points.shape[0] == 2:
        boundary_points = boundary_points.T

    # Center the points
    centroid = boundary_points.mean(axis=0)
    centered = boundary_points - centroid
    
    # PCA to find main axis
    pca = PCA(n_components=2)
    pca.fit(centered)
    
    # First principal component is the main axis
    main_axis = pca.components_[0]
    
    return main_axis

In [24]:
def calculate_alignment_similarity(main_axis: np.ndarray, compared_axis: np.ndarray) -> float:
    norm_main = np.linalg.norm(main_axis)
    norm_compared = np.linalg.norm(compared_axis)
    
    if norm_main == 0 or norm_compared == 0:
        return 0.0
    
    cosine_similarity = np.dot(main_axis, compared_axis) / (norm_main * norm_compared)
    return np.abs(cosine_similarity)

In [25]:
def calculate_alignment_angle(main_point: tuple[int, int], compared_point: tuple[int, int], main_axis: np.ndarray) -> int:
    direction_vector = np.array([
        compared_point[0] - main_point[0],
        compared_point[1] - main_point[1]
    ])
    
    direction_norm = np.linalg.norm(direction_vector)
    if direction_norm == 0:
        return 0.0
    
    direction_vector = direction_vector / direction_norm
    main_axis = main_axis / np.linalg.norm(main_axis)
    
    cos_angle = np.clip(np.dot(direction_vector, main_axis), -1.0, 1.0)
    
    # Calculate angle in radians, then convert to degrees
    angle_rad = acos(cos_angle)
    angle_deg = degrees(angle_rad)
    
    if angle_deg > 90:
        angle_deg = 180 - angle_deg
    angle_deg = 90 - angle_deg
    
    return np.round(angle_deg)

In [26]:
def get_distance(main_point: np.ndarray, compared_point: np.ndarray) -> float:
    return np.linalg.norm(main_point - compared_point)

In [27]:
def calculate_similarity_score(alignment: float, angle: float, weights: np.ndarray = np.array([0.5, 0.5])) -> float:
    alignment_score = alignment
    angle_score = 1.0 - (angle / 90.0)
    
    similarity = np.sum(np.array([alignment_score, angle_score]) * weights)
    
    return similarity

In [28]:
def get_median_distance(points: np.ndarray, neighbor_dict: dict[int, set[int]]) -> float:
    distances = []
    for i, point in enumerate(points):
        for neighbor in neighbor_dict[i]:
            if i < neighbor:
                neighbor_point = points[neighbor]
                distances.append(get_distance(point, neighbor_point))
    
    if len(distances) == 0:
        return 0.0
    
    return np.median(distances)

In [29]:
def build_neighbor_graph(points: np.ndarray, neighbor_dict: dict[int, set[int]], boundary_points: np.ndarray, distance_threshold: float) -> nx.Graph:
    G = nx.Graph()

    for i, (x, y) in enumerate(points):
        G.add_node(i, pos=(x, y))
    
    for i, point in enumerate(points):
        main_boundary = boundary_points[i]
        main_axis = get_nucleus_orientation(main_boundary)

        for neighbor in neighbor_dict[i]:
            neighbor_point = points[neighbor]
            distance = get_distance(point, neighbor_point)

            if distance <= distance_threshold:
                neighbor_boundary = boundary_points[neighbor]
                neighbor_axis = get_nucleus_orientation(neighbor_boundary)
                
                alignment = calculate_alignment_similarity(main_axis, neighbor_axis)
                angle = calculate_alignment_angle(point, neighbor_point, main_axis)
                

                G.add_edge(i, neighbor, distance=distance, alignment=alignment, angle=angle)

    for i in G.nodes():
        neighbors = list(G.neighbors(i))
        if len(neighbors) == 0:
            G.nodes[i]['best_similarity'] = 0.0
            continue
        
        neighbor_scores = []
        for neighbor in neighbors:
            edge_data = G[i][neighbor]
            similarity = calculate_similarity_score(
                edge_data['alignment'],
                edge_data['angle']
            )
            neighbor_scores.append((neighbor, similarity))
        
        # Get top 2 neighbors
        neighbor_scores.sort(key=lambda x: x[1], reverse=True)
        top_two = neighbor_scores[:2]
        
        # Calculate combined similarity score
        if len(top_two) == 2:
            best_similarity = (top_two[0][1] + top_two[1][1]) / 2.0
        elif len(top_two) == 1:
            best_similarity = top_two[0][1]
        else:
            best_similarity = 0.0
        
        G.nodes[i]['best_similarity'] = best_similarity

    return G

In [30]:
def filter_neighbor_graph(G: nx.Graph, n1: int, n2: int, alignment_threshold: float, angle_threshold: float, distance_threshold: float | None = None) -> bool:
    edge = G[n1][n2]

    alignment = edge["alignment"]
    angle = edge["angle"]
    distance = edge["distance"]

    alignment_ok = alignment >= alignment_threshold
    angle_ok = angle <= angle_threshold
    distance_ok = (distance_threshold is None) or (distance <= distance_threshold)

    return alignment_ok and angle_ok and distance_ok

In [31]:
def filter_graph_top_n(G: nx.Graph, n: int = 2) -> nx.Graph:
    G_filtered = G.copy()

    for node in G.nodes():
        neighbors = list(G.neighbors(node))
        if len(neighbors) <= n:
            continue
        
        neighbor_scores = []
        for neighbor in neighbors:
            edge_data = G[node][neighbor]
            similarity = calculate_similarity_score(
                edge_data['alignment'],
                edge_data['angle']
            )
            neighbor_scores.append((neighbor, similarity))
        neighbor_scores.sort(key=lambda x: x[1], reverse=True)
        top_n_neighbors = {neighbor for neighbor, _ in neighbor_scores[:n]}

        for neighbor in neighbors:
            if neighbor not in top_n_neighbors and G_filtered.has_edge(node, neighbor):
                G_filtered.remove_edge(node, neighbor)
        
    return G_filtered

In [32]:
def visualize_graph_overlay(image: np.ndarray, filtered_graph: nx.Graph, 
                           node_size: int = 50, edge_width: float = 2.0,
                           node_color: str = 'blue', edge_color: str = 'cyan',
                           alpha: float = 0.7, ax=None) -> None:
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 12))
        show_plot = True
    else:
        show_plot = False
    
    # Display image
    ax.imshow(image)
    ax.axis('off')

    components = list(nx.connected_components(filtered_graph))
    n_components = len(components)
    cmap = plt.cm.get_cmap('tab20')
    colors = [cmap(i / max(n_components - 1, 1)) for i in range(n_components)]

    node_to_color = {}
    for i, component in enumerate(components):
        for node in component:
            node_to_color[node] = colors[i]

    # Draw edges
    for (n1, n2) in filtered_graph.edges():
        pos1 = filtered_graph.nodes[n1]['pos']
        pos2 = filtered_graph.nodes[n2]['pos']
        edge_color = node_to_color[n1]
        ax.plot([pos1[1], pos2[1]], [pos1[0], pos2[0]], 
                color=edge_color, linewidth=edge_width, alpha=alpha)
    
    # Draw nodes
    for node in filtered_graph.nodes():
        pos = filtered_graph.nodes[node]['pos']
        node_color = node_to_color[node]
        ax.scatter(pos[1], pos[0], s=node_size, c=[node_color], 
                  linewidths=1, alpha=alpha, zorder=5)
    
    plt.tight_layout()

    if show_plot:
        plt.show()

### Plot created graph

In [1]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 2
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     visualize_graph_overlay(images[i], graph)

### Plot filtered graph

In [76]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 1.5
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     filtered_graph = nx.subgraph_view(
#         graph, 
#         filter_edge=lambda n1, n2: filter_neighbor_graph(
#             graph, n1, n2, 
#             alignment_threshold=0.6, 
#             angle_threshold=45.0,
#             distance_threshold=None
#         )
#     )
#     visualize_graph_overlay(images[i], filtered_graph)

### Plot filtered graph with only top two connections

In [73]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 1.5
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     top_2_graph = filter_graph_top_n(graph, 2)
#     filtered_graph = nx.subgraph_view(
#         top_2_graph, 
#         filter_edge=lambda n1, n2: filter_neighbor_graph(
#             graph, n1, n2, 
#             alignment_threshold=0.6, 
#             angle_threshold=45.0,
#             distance_threshold=None
#         )
#     )
#     visualize_graph_overlay(images[i], filtered_graph)

### Plot the best_similarity of the nuclei

In [None]:
def visualize_nodes_by_similarity(image: np.ndarray, filtered_graph: nx.Graph,
                                  node_size: int = 50, alpha: float = 0.7, ax=None) -> None:
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 12))
        show_plot = True
    else:
        show_plot = False
    
    ax.imshow(image)
    ax.axis('off')
    
    similarities = []
    positions = []
    
    for node in filtered_graph.nodes():
        similarity = filtered_graph.nodes[node].get('best_similarity', 0.0)
        pos = filtered_graph.nodes[node]['pos']
        similarities.append(similarity)
        positions.append([pos[0], pos[1]])
    
    if len(similarities) == 0:
        print("No nodes with best_similarity attribute found")
        plt.tight_layout()
        plt.show()
        return
    
    similarities = np.array(similarities)
    positions = np.array(positions)
    
    # Create blue-to-red colormap
    cmap = plt.cm.get_cmap('viridis')
    
    # Plot all nodes at once for better performance
    scatter = ax.scatter(positions[:, 1], positions[:, 0], 
                        s=node_size, c=similarities, 
                        cmap=cmap, alpha=alpha, zorder=5, vmin=0, vmax=1)
    
    plt.tight_layout()

    if show_plot:
        plt.show()

In [75]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 1.5
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     visualize_nodes_by_similarity(images[i], graph)

### Plot axis of nuclei

In [48]:
def get_axis_for_nuclei(boundary_points: np.ndarray) -> np.ndarray:
    all_axises = []
    for boundary in boundary_points:
        axis = get_nucleus_orientation(boundary)
        all_axises.append(axis)

    return np.array(all_axises)

In [49]:
def visualize_nuclei_axes(image: np.ndarray, points: np.ndarray, axes: np.ndarray,
                          line_length: float = 25.0, point_size: int = 30,
                          line_color: str = 'green', point_color: str = 'blue',
                          alpha: float = 0.8, linewidth: float = 2.0, ax=None) -> None:
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 12))
        show_plot = True
    else:
        show_plot = False
    
    ax.imshow(image)
    ax.axis('off')
    
    for point, axis in zip(points, axes):
        # Calculate line endpoints (extending in both directions)
        start_point = point - axis * line_length
        end_point = point + axis * line_length
        
        # Draw line through the point
        ax.plot([start_point[1], end_point[1]], 
               [start_point[0], end_point[0]],
               color=line_color, linewidth=linewidth, alpha=alpha, zorder=4)
    
    ax.scatter(points[:, 1], points[:, 0], s=point_size, c=point_color,
              alpha=0.5, zorder=5)
    
    plt.tight_layout()

    if show_plot:
        plt.show()

In [75]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     boundary_points = data_dict["coord"]
#     points = data_dict["points"]
#     all_axises = get_axis_for_nuclei(boundary_points)
#     visualize_nuclei_axes(images[i], points, all_axises)

In [51]:
def visualize_graph_overlay_with_axes(image: np.ndarray, filtered_graph: nx.Graph, filtered_data_dict: dict[str, np.ndarray],
                           node_size: int = 50, edge_width: float = 2.0,
                           node_color: str = 'blue', edge_color: str = 'cyan',
                           alpha: float = 0.5) -> None:
    fig, ax = plt.subplots(figsize=(12, 12))
    
    # Display image
    ax.imshow(image)
    ax.axis('off')

    points = filtered_data_dict["points"]
    boundary_points = filtered_data_dict["coord"]
    axes = get_axis_for_nuclei(boundary_points)
    
    # Draw edges
    for (n1, n2) in filtered_graph.edges():
        pos1 = filtered_graph.nodes[n1]['pos']
        pos2 = filtered_graph.nodes[n2]['pos']
        ax.plot([pos1[1], pos2[1]], [pos1[0], pos2[0]], 
                color=edge_color, linewidth=edge_width, alpha=alpha)
    
    # Draw nodes
    for node in filtered_graph.nodes():
        pos = filtered_graph.nodes[node]['pos']
        ax.scatter(pos[1], pos[0], s=node_size, c=node_color, 
                  edgecolors='black', linewidths=1, alpha=alpha, zorder=5)

    # Draw axes
    for point, axis in zip(points, axes):
        start_point = point - axis * 25
        end_point = point + axis * 25
        
        ax.plot([start_point[1], end_point[1]], 
               [start_point[0], end_point[0]],
               color="red", linewidth=2, alpha=alpha, zorder=4)
    
    plt.tight_layout()
    plt.show()

In [52]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 1.5
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     filtered_graph = nx.subgraph_view(
#         graph, 
#         filter_edge=lambda n1, n2: filter_neighbor_graph(
#             graph, n1, n2, 
#             alignment_threshold=0.7, 
#             angle_threshold=45.0,
#             distance_threshold=None
#         )
#     )
#     visualize_graph_overlay_with_axes(images[i], filtered_graph, filtered_data_dict)

In [53]:
# for i, data_dict in enumerate(nuclei_data_dicts):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 1.5
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     filtered_graph = nx.subgraph_view(
#         graph, 
#         filter_edge=lambda n1, n2: filter_neighbor_graph(
#             graph, n1, n2, 
#             alignment_threshold=0.6, 
#             angle_threshold=45.0,
#             distance_threshold=None
#         )
#     )

#     fig, axes = plt.subplots(1, 3, figsize=(18, 12))


#     all_axises = get_axis_for_nuclei(boundary_points)
#     visualize_nuclei_axes(images[i], points, all_axises, ax=axes[0], line_length=25, point_size=10, linewidth=1.0)
#     axes[0].set_title("Segmented Nuclei with Axes")
    
#     visualize_nodes_by_similarity(images[i], filtered_graph, ax=axes[1], node_size=10)
#     axes[1].set_title("Nuclei with Similarity")

#     visualize_graph_overlay(images[i], filtered_graph, ax=axes[2], node_size=10, edge_width=1.0)
#     axes[2].set_title("Graph Overlay")
    
#     plt.tight_layout()
#     plt.show()

### Find unorganized regions with HDBClustering

In [56]:
def detect_unorganized_regions(data_dict: dict[str,any], distance_threshold: float = 0.4, min_samples_fraction: float = 0.01) -> dict[int, list]:
    points = data_dict['points']
    neighbor_dict, _ = get_delaunay_neighbors(points)
    median_distance = get_median_distance(points, neighbor_dict)
    num_points = len(points)
    clusterer = DBSCAN(eps=distance_threshold * median_distance, min_samples=int(min_samples_fraction * num_points))
    labels = clusterer.fit_predict(points)
    clusters = {}
    for i, label in enumerate(labels):
        if label == -1:
            continue
        if label not in clusters:
            clusters[label] = []
        clusters[label].append(points[i])
    return clusters

In [59]:
# for image, data_dict, nuclei_mask in zip(images, nuclei_data_dicts, nuclei_masks):
#     filtered_data_dict = filter_data_dict(nuclei_mask, data_dict)
#     hdb_cluster = detect_unorganized_regions(filtered_data_dict)
#     plt.figure(figsize=(18, 12))
#     plt.subplot(1, 2, 1)
#     plt.imshow(nuclei_mask)
#     plt.axis('off')

#     plt.subplot(1, 2, 2)
#     for region_idx, (region_id, points) in enumerate(hdb_cluster.items()):
#         points_array = np.array(points)
#         plt.scatter(points_array[:, 1], points_array[:, 0], s=10, alpha=0.7, label=f'Region {region_id}')


#     plt.title(f'Nuclei Clusters ({len(hdb_cluster)} regions)')
#     #plt.legend(loc='upper right', fontsize=8)
#     plt.imshow(image)
#     plt.axis('off')
#     plt.show()

In [60]:
def get_median_similarity(filtered_graph: nx.Graph) -> float:
    similarities = []
    for node in filtered_graph.nodes():
        similarity = filtered_graph.nodes[node].get("best_similarity", 0.0)
        similarities.append(similarity)
    return np.median(similarities)

In [69]:
def visualize_classification(
    image: np.ndarray, 
    filtered_graph: nx.Graph, 
    classifications: dict[int, str],
    node_size: int = 50, 
    alpha: float = 0.7, 
    ax=None
) -> None:
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 12))
        show_plot = True
    else:
        show_plot = False
    
    ax.imshow(image)
    ax.axis('off')
    
    organized_pos = []
    unorganized_pos = []
    
    for node in filtered_graph.nodes():
        pos = filtered_graph.nodes[node]['pos']
        classification = classifications.get(node, 'unorganized')
        
        if classification == 'organized':
            organized_pos.append([pos[0], pos[1]])
        else:
            unorganized_pos.append([pos[0], pos[1]])
    
    # Plot organized nuclei in green
    if len(organized_pos) > 0:
        organized_pos = np.array(organized_pos)
        ax.scatter(organized_pos[:, 1], organized_pos[:, 0], 
                  s=node_size, c='green', alpha=alpha, 
                  label=f"Organized", zorder=5)
    
    # Plot unorganized nuclei in red
    if len(unorganized_pos) > 0:
        unorganized_pos = np.array(unorganized_pos)
        ax.scatter(unorganized_pos[:, 1], unorganized_pos[:, 0], 
                  s=node_size, c='red', alpha=alpha, 
                  label=f"Unorganized", zorder=5)
    
    ax.legend(loc='upper right')
    
    if show_plot:
        plt.tight_layout()
        plt.show()


In [None]:
def classify_nuclei_weighted(
    filtered_graph: nx.Graph,
    k_neighbors: int = 5,
    use_second_order: bool = True,
    threshold: float = None,
    distance_weight: bool = True
) -> dict[int, str]:
    if threshold is None:
        threshold = get_median_similarity(filtered_graph)
    
    classifications = {}
    
    for node in filtered_graph.nodes():
        neighbors_with_dist = []
        
        for neighbor in filtered_graph.neighbors(node):
            edge_data = filtered_graph[node][neighbor]
            distance = edge_data.get('distance', 1.0)
            neighbors_with_dist.append((neighbor, distance))
        
        neighbors_with_dist.sort(key=lambda x: x[1])
        k_nearest = neighbors_with_dist[:k_neighbors]
        
        weighted_similarities = []
        total_weight = 0.0
        
        for neighbor, distance in k_nearest:
            sim = filtered_graph.nodes[neighbor].get('best_similarity', 0.0)
            
            if distance_weight and distance > 0:
                # Weight by inverse distance
                weight = 1.0 / (distance)
            else:
                weight = 1.0
            
            weighted_similarities.append(sim * weight)
            total_weight += weight

        if use_second_order and len(k_nearest) > 0:
            second_order_neighbors = {}
            
            for neighbor, _ in k_nearest:
                for second_neighbor in filtered_graph.neighbors(neighbor):
                    if second_neighbor not in second_order_neighbors and second_neighbor != node:
                        
                        if second_neighbor not in [n for n, d in k_nearest]:
                            try:
                                edge_data = filtered_graph[neighbor][second_neighbor]
                                dist = edge_data.get('distance', 1.0)
                                second_order_neighbors[second_neighbor] = dist
                            except:
                                pass
            
            for neighbor, distance in second_order_neighbors.items():
                sim = filtered_graph.nodes[neighbor].get('best_similarity', 0.0)
                
                # Reduce weight for second-order neighbors
                if distance_weight and distance > 0:
                    weight = 0.25 / (distance + 1e-6)
                else:
                    weight = 0.25
                
                weighted_similarities.append(sim * weight)
                total_weight += weight
        
        if total_weight > 0:
            avg_weighted_similarity = sum(weighted_similarities) / total_weight
        else:
            avg_weighted_similarity = filtered_graph.nodes[node].get('best_similarity', 0.0)
        
        if avg_weighted_similarity > threshold:
            classifications[node] = "organized"
        else:
            classifications[node] = "unorganized"
    
    return classifications


In [76]:
# for i, (image, data_dict, nuclei_mask) in enumerate(zip(images, nuclei_data_dicts, nuclei_masks)):
#     filtered_data_dict = filter_data_dict(masks[i], data_dict)
#     points = filtered_data_dict["points"]
#     neighbor_dict, _ = get_delaunay_neighbors(points)
#     boundary_points = filtered_data_dict["coord"]
#     dist_threshold = get_median_distance(points, neighbor_dict) * 1.5
#     graph = build_neighbor_graph(points=points, neighbor_dict=neighbor_dict, boundary_points=boundary_points, distance_threshold=dist_threshold)
#     filtered_graph = nx.subgraph_view(
#         graph, 
#         filter_edge=lambda n1, n2: filter_neighbor_graph(
#             graph, n1, n2, 
#             alignment_threshold=0.6, 
#             angle_threshold=45.0,
#             distance_threshold=None
#         )
#     )

#     classifications_weighted = classify_nuclei_weighted(
#         filtered_graph=filtered_graph,
#         k_neighbors=5,
#         use_second_order=True,
#         threshold=None,
#         distance_weight=True
#     )
#     visualize_classification(image, filtered_graph, classifications_weighted)
