# Intro

This notebook explores the usage of alpha shapes to detect layers.
Therefore, it performs multiple steps:
1. Segment tissue with a trained model
2. Segment nuclei using StarDist
3. Create an overlay of the nuclei and the tissue
4. Calculate relative size of nuclei to detected tissue
5. Apply delauney triangulation on nuclei
6. Apply alpha shapes with relative threshold depending on nuclei and tissue size (avg)

# Setup

In [1]:
# Imports 
import sys
from pathlib import Path

# Add project root to path
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

In [2]:
import tifffile
import os
import matplotlib.pyplot as plt
import numpy as np

from scipy.spatial import Delaunay
from collections import defaultdict

In [3]:
# Load four images, which present a good example of layers
LAYER_PATH = project_root / "data/layer_examples"

paths = [(LAYER_PATH / image_path) for image_path in os.listdir(LAYER_PATH)]
images = list(map(tifffile.imread, paths))

In [4]:
# Normalize images
from src.utils.reinhard_normalizer import ReinhardNormalizer

normalizer = ReinhardNormalizer()
images = [normalizer.normalize(img) for img in images]

  from .autonotebook import tqdm as notebook_tqdm

objc[5369]: Class GNotificationCenterDelegate is implemented in both /opt/anaconda3/envs/research_project/lib/libgio-2.0.0.dylib (0x36d78c6d8) and /opt/anaconda3/envs/research_project/lib/python3.12/site-packages/openslide_bin/libopenslide.1.dylib (0x377095318). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
  check_for_updates()



## 1. Segment tissue

In [5]:
# Imports
import torch
from src.models.model_loader import ModelLoader
from skimage.measure import label

In [6]:
# Specify model to load
model_loader = ModelLoader()
MODEL_CFG = "unet_2"
model = model_loader.load_cnn_model(MODEL_CFG, "unet_2c")

Loaded CNN: unet_2c


In [7]:
from src.data.preprocessing import inference_processing
from skimage.transform import resize
ORG_RES = (1920, 2560)

In [8]:
masks = []
labeled_tissues = []
device = "mps" if torch.mps.is_available() else "cpu"
print(device)
for img in images:
    img = inference_processing(img, device)

    with torch.no_grad():
        pred_logits = model(img)
        pred_mask = torch.argmax(pred_logits, dim=1).squeeze()
        pred_mask = pred_mask.cpu().numpy()
        pred_mask = resize(pred_mask, ORG_RES, anti_aliasing=True)
    labeled_tissue = label(pred_mask > 0, connectivity=2)
    labeled_tissues.append(labeled_tissue)
    masks.append(pred_mask)

mps


In [9]:
from src.utils.helpers import compare_two_images

In [71]:
# for img, mask in zip(images, masks):
#     compare_two_images(img, mask, "Normalized Image", "Predicted Mask")

### 1 b Try to apply Watershed

In [11]:
from scipy import ndimage
from skimage.segmentation import watershed
from skimage.feature import peak_local_max

def separate_touching_regions(binary_mask, min_distance=20, footprint_size=25):
    """Separate touching regions using watershed on distance transform."""
    distance = ndimage.distance_transform_edt(binary_mask)
    
    # Find local maxima in distance transform (these become markers/seeds)
    # Each "core" of a tissue region will have a local maximum
    coords = peak_local_max(
        distance, 
        min_distance=min_distance,  # Minimum distance between peaks
        footprint=np.ones((footprint_size, footprint_size)),
        labels=binary_mask
    )
    
    # Create markers from the peaks
    markers = np.zeros(distance.shape, dtype=int)
    markers[tuple(coords.T)] = np.arange(1, len(coords) + 1)
    markers = ndimage.label(markers)[0]
    
    # Apply watershed - negative distance so "valleys" become the separation lines
    labels = watershed(-distance, markers, mask=binary_mask)
    
    return labels

In [12]:
binary_masks = [(mask > 0).astype(int) for mask in masks]

In [13]:
watershed_masks = [separate_touching_regions(mask) for mask in binary_masks]

In [72]:
# for img, mask in zip(images, watershed_masks):
#     compare_two_images(img, mask, "Normalized Image", "Predicted Mask")

### 1c Try to apply Morphological Operations

In [15]:
from skimage.morphology import binary_opening, binary_closing, disk
from skimage.measure import label

def separate_with_morphology(binary_mask, open_radius=10, close_radius=3):
    """
    Separate touching regions using morphological operations.
    
    Parameters:
    - open_radius: Larger = more aggressive separation (breaks thicker connections)
    - close_radius: Fills small holes created by opening
    """
    # Opening: erosion then dilation - breaks thin connections
    opened = binary_opening(binary_mask, disk(open_radius))
    
    # Closing: dilation then erosion - fills small holes
    closed = binary_closing(opened, disk(close_radius))
    
    # Label the separated regions
    labeled = label(closed, connectivity=2)
    
    return labeled, closed



In [16]:
morph_masks = [separate_with_morphology(mask)[1] for mask in masks]

In [73]:
# for img, mask in zip(images, morph_masks):
#     compare_two_images(img, mask, "Normalized Image", "Cleaned Mask")

## 2. Segment Nuclei with Stardist

In [18]:
# Imports
from stardist.models import StarDist2D
from stardist.plot import render_label
from src.utils.helpers import cut_out_image
from skimage.exposure import rescale_intensity




In [19]:
stardist_model = StarDist2D.from_pretrained("2D_versatile_he")

Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.


In [20]:
nuclei_masks = []
nuclei_data_dicts = []

In [21]:
for img, mask in zip(images, masks):
    image_normed = rescale_intensity(img, out_range=(0, 1))
    labels, data_dict = stardist_model.predict_instances(image_normed, axes='YXC', prob_thresh=0.6, nms_thresh=0.0, return_labels=True)
    filtered_labels = cut_out_image(labels, mask)
    binary_labels = (filtered_labels > 0).astype(np.uint8)
    nuclei_masks.append(binary_labels)
    nuclei_data_dicts.append(data_dict)

In [74]:
# for img, mask in zip(images, nuclei_masks):
#     compare_two_images(img, mask, "Normalized Image", "Filtered nuclei mask")

## 3. Create overlay of mask and nuclei

In [23]:
def combine_masks(tissue_mask, nuclei_mask):
    """
    Combines tissue and nuclei masks into a single mask.
    
    Args:
        tissue_mask: Binary tissue mask (H, W), values 0 or 1
        nuclei_mask: Binary nuclei mask (H, W), values 0 or 1
    
    Returns:
        combined_mask: Single mask with values:
            0 = background
            1 = tissue (without nuclei)
            2 = nuclei
    """
    combined_mask = np.zeros_like(tissue_mask, dtype=np.uint8)
    
    combined_mask[tissue_mask > 0] = 1
    
    combined_mask[nuclei_mask > 0] = 2
    
    return combined_mask

In [24]:
combined_masks = [combine_masks(tissue_mask, nuclei_mask) for tissue_mask, nuclei_mask in zip(masks, nuclei_masks)]

In [75]:
# from matplotlib.colors import ListedColormap

# for image, combined_mask in zip(images, combined_masks):
#     plt.figure(figsize=(12, 6))
#     plt.subplot(1, 2, 1)
#     plt.imshow(image)
#     plt.axis("off")
#     plt.title("Normalized image")

#     plt.subplot(1, 2, 2)
#     colors = ['black', 'white', 'blue']
#     cmap = ListedColormap(colors)
#     plt.imshow(combined_mask, cmap=cmap, vmin=0, vmax=2)
#     plt.axis("off")
#     plt.title("Tissue and nuclei mask")

#     plt.show()

## 4. Calculate relative size of nuclei compared to detected tissue

In [26]:
import pandas as pd
import numpy as np
from scipy import ndimage
from skimage.measure import label, regionprops

In [27]:
def calculate_metrics_from_combined(combined_mask: np.ndarray) -> dict[str, float]:
    """
    Calculate metrics from a combined mask, including average sizes of individual regions.
    
    Args:
        combined_mask (np.ndarray): Mask with values 0 (background), 1 (tissue), 2 (nuclei)
    
    Returns:
        result(dict): dict with metrics
    """
    total_pixels = combined_mask.size
    background_pixels = np.sum(combined_mask == 0)
    tissue_pixels = np.sum(combined_mask == 1)
    nuclei_pixels = np.sum(combined_mask == 2)
    
    total_tissue_pixels = tissue_pixels + nuclei_pixels
    
    tissue_percentage = (tissue_pixels / total_pixels) * 100
    nuclei_percentage = (nuclei_pixels / total_pixels) * 100
    total_tissue_percentage = (total_tissue_pixels / total_pixels) * 100
    
    nuclei_to_tissue_ratio = (nuclei_pixels / total_tissue_pixels * 100) if total_tissue_pixels > 0 else 0
    
    nuclei_binary = (combined_mask == 2).astype(np.uint8)
    nuclei_labeled = label(nuclei_binary, connectivity=2)
    nuclei_regions = regionprops(nuclei_labeled)
    
    nuclei_sizes = [region.area for region in nuclei_regions]
    num_nuclei = len(nuclei_sizes)
    avg_nuclei_size = np.mean(nuclei_sizes) if nuclei_sizes else 0
    median_nuclei_size = np.median(nuclei_sizes) if nuclei_sizes else 0
    std_nuclei_size = np.std(nuclei_sizes) if nuclei_sizes else 0
    
    tissue_binary = (combined_mask > 0).astype(np.uint8)
    tissue_labeled = label(tissue_binary, connectivity=2)
    tissue_regions = regionprops(tissue_labeled)
    
    tissue_sizes = [region.area for region in tissue_regions]
    num_tissue_regions = len(tissue_sizes)
    avg_tissue_size = np.mean(tissue_sizes) if tissue_sizes else 0
    median_tissue_size = np.median(tissue_sizes) if tissue_sizes else 0
    std_tissue_size = np.std(tissue_sizes) if tissue_sizes else 0
    
    # Ratio of average sizes
    avg_size_ratio = (avg_nuclei_size / avg_tissue_size) if avg_tissue_size > 0 else 0
    
    result = {
        # Number of pixels
        'background_pixels': background_pixels,
        'tissue_only_pixels': tissue_pixels,
        'nuclei_pixels': nuclei_pixels,
        'total_tissue_pixels': total_tissue_pixels,
        'tissue_only_percentage': tissue_percentage,
        'nuclei_percentage': nuclei_percentage,
        'total_tissue_percentage': total_tissue_percentage,
        'nuclei_to_tissue_ratio': nuclei_to_tissue_ratio,
        
        # Size metrics
        'num_nuclei': num_nuclei,
        'avg_nuclei_size': avg_nuclei_size,
        'median_nuclei_size': median_nuclei_size,
        'std_nuclei_size': std_nuclei_size,
        'min_nuclei_size': min(nuclei_sizes) if nuclei_sizes else 0,
        'max_nuclei_size': max(nuclei_sizes) if nuclei_sizes else 0,
        
        'num_tissue_regions': num_tissue_regions,
        'avg_tissue_region_size': avg_tissue_size,
        'median_tissue_region_size': median_tissue_size,
        'std_tissue_region_size': std_tissue_size,
        'min_tissue_region_size': min(tissue_sizes) if tissue_sizes else 0,
        'max_tissue_region_size': max(tissue_sizes) if tissue_sizes else 0,
        
        'avg_nuclei_to_tissue_size_ratio': avg_size_ratio,
    }
    return result

metrics_list = []
for idx, combined in enumerate(combined_masks):
    metrics = calculate_metrics_from_combined(combined)
    metrics['image_id'] = idx
    metrics_list.append(metrics)

df = pd.DataFrame(metrics_list)

In [28]:
df.head(10)

Unnamed: 0,background_pixels,tissue_only_pixels,nuclei_pixels,total_tissue_pixels,tissue_only_percentage,nuclei_percentage,total_tissue_percentage,nuclei_to_tissue_ratio,num_nuclei,avg_nuclei_size,...,min_nuclei_size,max_nuclei_size,num_tissue_regions,avg_tissue_region_size,median_tissue_region_size,std_tissue_region_size,min_tissue_region_size,max_tissue_region_size,avg_nuclei_to_tissue_size_ratio,image_id
0,2194476,2251316,469408,2720724,45.803141,9.55013,55.353271,17.253055,997,470.820461,...,9.0,1772.0,21,129558.285714,1273.0,330071.02525,63.0,1466285.0,0.003634,0
1,4099652,664371,151177,815548,13.516663,3.075704,16.592367,18.536861,245,617.04898,...,1.0,1639.0,18,45308.222222,2966.0,121953.464164,72.0,536825.0,0.013619,1
2,4125262,631641,158297,789938,12.850769,3.220561,16.07133,20.039168,276,573.539855,...,5.0,1404.0,7,112848.285714,8322.0,190278.909224,126.0,541178.0,0.005082,2
3,2687775,1956668,270757,2227425,39.808512,5.508565,45.317078,12.155606,696,389.018678,...,87.0,1556.0,4,556856.25,5328.5,958342.753269,30.0,2216738.0,0.000699,3


## 5. Apply Alpha Shapes

In [29]:
# Create clusters
cluster_list = []

for tissue_labeled, data_dict in zip(labeled_tissues, nuclei_data_dicts):
    clusters = {}

    nucleus_centers = data_dict['points']

    for idx, (x, y) in enumerate(nucleus_centers):
        x_int, y_int = int(x), int(y)
        
        # Check bounds
        if 0 <= x_int < tissue_labeled.shape[0] and 0 <= y_int < tissue_labeled.shape[1]:
            region_id = tissue_labeled[x_int, y_int]
            
            if region_id == 0:
                continue
            
            if region_id not in clusters:
                clusters[region_id] = []
            
            clusters[region_id].append((x, y))
    cluster_list.append(clusters)

In [30]:
import alphashape
from shapely.geometry import Polygon, MultiPolygon
import numpy as np

def extract_layer(points, alpha=0.01):
    """
    Extract the outermost layer of points using alpha shapes.
    
    Args:
        points: List of (x, y) tuples
        alpha: Smaller = tighter fit, larger = looser fit
    
    Returns:
        remaining_points: Points not in the outer layer
        layer_points: Points in the outer layer
    """
    # Create alpha shape (concave hull)
    alpha_shape = alphashape.alphashape(points, alpha)
    
    # Extract boundary coordinates
    layer_coords = []
    if isinstance(alpha_shape, Polygon):
        layer_coords = list(alpha_shape.exterior.coords)
    elif isinstance(alpha_shape, MultiPolygon):
        for polygon in alpha_shape.geoms:
            layer_coords.extend(list(polygon.exterior.coords))
    
    # Remove layer points from original points
    layer_coords_set = set(layer_coords)
    remaining_points = [p for p in points if p not in layer_coords_set]
    
    return remaining_points, layer_coords

def detect_layers(cluster_points, alpha=0.01, max_layers=10):
    """
    Iteratively detect layers in a cluster of points.
    
    Returns:
        layers: {0: [(x, y), ...], 1: [(x, y), ...], ...}
    """
    layers = {}
    current_points = cluster_points.copy()
    
    for layer_id in range(max_layers):
        if len(current_points) < 4:  # Need at least 4 points for alpha shape
            break
        
        try:
            current_points, layer_coords = extract_layer(current_points, alpha)
            layers[layer_id] = layer_coords
            
            if not current_points:
                print(f"All points assigned to layers")
                break
                
        except Exception as e:
            print(f"Stopped at layer {layer_id}: {e}")
            break
    
    return layers




In [31]:
# Take a test image and visualize clusters
test_image = images[1]
test_clusters = cluster_list[1]

In [76]:
# plt.figure(figsize=(12, 12))
# for region_idx, (region_id, points) in enumerate(test_clusters.items()):
#     points_array = np.array(points)
#     plt.scatter(points_array[:, 1], points_array[:, 0], s=20, alpha=0.5, label=f'Region {region_id}')

# plt.title(f'Nuclei Clusters ({len(clusters)} regions)')
# plt.legend(loc='upper right', fontsize=8)
# plt.imshow(test_image)
# plt.axis('off')
# plt.show()

In [33]:
alpha_shapes = {}
for region_id, cluster_points in test_clusters.items():
    points_array = np.array(cluster_points)
    alpha_shapes[region_id] = alphashape.alphashape(points_array, alpha=0.01)


all_layer_coords = {}
idx = 0
for _, alpha_shape in alpha_shapes.items():
    layer_coords = []
    if isinstance(alpha_shape, Polygon):
        layer_coords = list(alpha_shape.exterior.coords)
    elif isinstance(alpha_shape, MultiPolygon):
        for polygon in alpha_shape.geoms:
            layer_coords.extend(list(polygon.exterior.coords))
    if layer_coords:
        all_layer_coords[idx] = layer_coords
        idx += 1

In [77]:
# import matplotlib.pyplot as plt

# fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# axes[0].imshow(test_image)
# region_colors = plt.cm.tab10(np.linspace(0, 1, len(test_clusters)))

# for idx, (region_id, points) in enumerate(test_clusters.items()):
#     points_array = np.array(points)
#     axes[0].scatter(points_array[:, 1], points_array[:, 0], 
#                    c=[region_colors[idx]], s=20, alpha=0.6, 
#                    label=f'Region {region_id}')

# axes[0].set_title(f'Nuclei Points ({len(test_clusters)} regions)')
# axes[0].axis('off')
# if len(test_clusters) <= 10:
#     axes[0].legend()

# axes[1].imshow(test_image)

# for idx, layer_coords in all_layer_coords.items():
#     layer_array = np.array(layer_coords)
    
#     axes[1].plot(layer_array[:, 1], layer_array[:, 0], 
#                 color=region_colors[idx % len(region_colors)], 
#                 linewidth=2.5, alpha=0.8,
#                 label=f'Region {idx}')

#     axes[1].fill(layer_array[:, 1], layer_array[:, 0], 
#                 color=region_colors[idx % len(region_colors)], 
#                 alpha=0.15)

# axes[1].set_title('Alpha Shapes (Concave Hulls)')
# axes[1].axis('off')
# if len(all_layer_coords) <= 10:
#     axes[1].legend()

# plt.tight_layout()
# plt.show()

In [35]:
test_layers = {}
for region_id, cluster_points in test_clusters.items():
    layers = detect_layers(cluster_points)
    test_layers[region_id] = layers

All points assigned to layers


In [36]:
all_image_layers = []

for img_idx, (image, clusters) in enumerate(zip(images, cluster_list)):
    print(f"Processing Image {img_idx + 1}/{len(images)}")
    
    image_layers = {}
    
    for region_id, cluster_points in clusters.items():
        
        if len(cluster_points) < 10:
            continue
        
        layers = detect_layers(cluster_points, alpha=0.01, max_layers=10)
        image_layers[region_id] = layers
    
    all_image_layers.append(image_layers)

Processing Image 1/4
Processing Image 2/4
Processing Image 3/4
Processing Image 4/4


In [78]:
# import matplotlib.pyplot as plt
# import numpy as np
# from matplotlib.colors import ListedColormap

# layer_colors = plt.cm.rainbow(np.linspace(0, 1, 10))

# for img_idx, (image, image_layers, clusters) in enumerate(zip(images, all_image_layers, cluster_list)):
#     fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    
#     axes[0].imshow(image)
#     axes[0].set_title(f'Image {img_idx + 1}: Original')
#     axes[0].axis('off')
    
#     axes[1].imshow(image)
#     region_colors = plt.cm.tab10(np.linspace(0, 1, max(len(clusters), 1)))
    
#     for region_idx, (region_id, points) in enumerate(clusters.items()):
#         points_array = np.array(points)
#         axes[1].scatter(points_array[:, 1], points_array[:, 0], 
#                        c=[region_colors[region_idx % len(region_colors)]], 
#                        s=10, alpha=0.5, label=f'Region {region_id}')
    
#     axes[1].set_title(f'Nuclei Clusters ({len(clusters)} regions)')
#     axes[1].axis('off')
#     if len(clusters) <= 10:
#         axes[1].legend(loc='upper right', fontsize=8)
    
#     axes[2].imshow(image)
    
#     for region_id, layers in image_layers.items():
#         for layer_id, layer_points in layers.items():
#             if len(layer_points) > 0:
#                 layer_array = np.array(layer_points)
#                 axes[2].plot(layer_array[:, 1], layer_array[:, 0], 
#                            color=layer_colors[layer_id % len(layer_colors)], 
#                            linewidth=2.5, alpha=0.8,
#                            label=f'Layer {layer_id}' if region_id == list(image_layers.keys())[0] else '')
    
#     axes[2].set_title(f'Detected Layers (colored by depth)')
#     axes[2].axis('off')
    

#     handles, labels = axes[2].get_legend_handles_labels()
#     by_label = dict(zip(labels, handles))
#     axes[2].legend(by_label.values(), by_label.keys(), loc='upper right', fontsize=8)
    
#     plt.tight_layout()
#     plt.show()

### Try HDBSCAN for better Clustering

In [38]:
import hdbscan

def create_hdbscan_clusters(tissue_labeled: np.ndarray, data_dict: dict[str,any], cluster_size_fraction: float, min_samples_fraction) -> dict[int, list]:
    nucleus_centers = data_dict['points']

    number_of_points = len(nucleus_centers)
    number_of_tissues = len(np.unique(tissue_labeled)) - 1
    expected_per_tissue = number_of_points / max(number_of_tissues, 1)

    min_cluster_size = max(5, int(expected_per_tissue * 0.2))
    min_samples = max(3, int(min_cluster_size * min_samples_fraction))

    filtered_centers = []
    for x, y in nucleus_centers:
        x_int, y_int = int(x), int(y)
        
        if 0 <= x_int < tissue_labeled.shape[0] and 0 <= y_int < tissue_labeled.shape[1]:
            region_id = tissue_labeled[x_int, y_int]
            
            if region_id > 0:
                filtered_centers.append((x, y))
    
    if len(filtered_centers) < 2:
        cluster_list.append({})
    
    points_array = np.array(filtered_centers)
    
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
    labels = clusterer.fit_predict(points_array)

    clusters = {}
    for idx, label in enumerate(labels):
        if label == -1:
            continue
        
        if label not in clusters:
            clusters[label] = []
        
        clusters[label].append(tuple(filtered_centers[idx]))
    
    return clusters

In [39]:
cluster_fractions = np.arange(0, 1.05, 0.05) 
sample_fractions = np.arange(0, 1.05, 0.05) 
TARGET_CLUSTERS = np.array([10, 8, 2, 5])

In [40]:
import warnings
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed")

In [41]:
best_cluster_fraction = 0
best_sample_fraction = 0
best_similarity = float("inf")

for cluster_fraction in cluster_fractions:
    for sample_fraction in sample_fractions:
        cluster_list = []
        for tissue_labeled, data_dict in zip(labeled_tissues, nuclei_data_dicts):
            clusters = create_hdbscan_clusters(tissue_labeled, data_dict, cluster_fraction, sample_fraction)
            cluster_list.append(len(clusters))
        
        cluster_list = np.array(cluster_list)
        norm = np.linalg.norm(cluster_list - TARGET_CLUSTERS)
        if norm < best_similarity:
            best_similarity = norm
            best_cluster_fraction = cluster_fraction
            best_sample_fraction = sample_fraction
            print(f"Found new best values! Best cluster fraction: {best_cluster_fraction}, best sample fraction: {best_sample_fraction}")

Found new best values! Best cluster fraction: 0.0, best sample fraction: 0.0
Found new best values! Best cluster fraction: 0.0, best sample fraction: 0.15000000000000002
Found new best values! Best cluster fraction: 0.0, best sample fraction: 0.4
Found new best values! Best cluster fraction: 0.0, best sample fraction: 0.5
Found new best values! Best cluster fraction: 0.0, best sample fraction: 0.8
Found new best values! Best cluster fraction: 0.0, best sample fraction: 0.9500000000000001


In [42]:
hdb_cluster_list = []
for tissue_labeled, data_dict in zip(labeled_tissues, nuclei_data_dicts):
    clusters = create_hdbscan_clusters(tissue_labeled, data_dict, best_cluster_fraction, best_sample_fraction)
    hdb_cluster_list.append(clusters)

In [79]:
# for image, hdb_cluster, nuclei_mask in zip(images, hdb_cluster_list, nuclei_masks):
#     plt.figure(figsize=(18, 12))
#     plt.subplot(1, 2, 1)
#     plt.imshow(nuclei_mask)
#     plt.axis('off')

#     plt.subplot(1, 2, 2)
#     for region_idx, (region_id, points) in enumerate(hdb_cluster.items()):
#         points_array = np.array(points)
#         plt.scatter(points_array[:, 1], points_array[:, 0], s=10, alpha=0.7, label=f'Region {region_id}')


#     plt.title(f'Nuclei Clusters ({len(hdb_cluster)} regions)')
#     #plt.legend(loc='upper right', fontsize=8)
#     plt.imshow(image)
#     plt.axis('off')
#     plt.show()

In [44]:
hdb_all_image_layers = []

for img_idx, (image, clusters) in enumerate(zip(images, hdb_cluster_list)):
    print(f"Processing Image {img_idx + 1}/{len(images)}")
    
    image_layers = {}

    stats = df.iloc[img_idx]
    median_nuclei_size = stats["median_nuclei_size"] * 0.001
    
    for region_id, cluster_points in clusters.items():
        
        if len(cluster_points) < 10:
            continue
        
        layers = detect_layers(cluster_points, alpha=0.01, max_layers=10)
        image_layers[region_id] = layers
    
    hdb_all_image_layers.append(image_layers)

Processing Image 1/4
All points assigned to layers
All points assigned to layers
All points assigned to layers
Processing Image 2/4
Processing Image 3/4
Processing Image 4/4


In [45]:
df["median_nuclei_size"]

0    428.0
1    616.0
2    537.0
3    352.0
Name: median_nuclei_size, dtype: float64

In [80]:
# import matplotlib.pyplot as plt
# import numpy as np
# from matplotlib.colors import ListedColormap

# layer_colors = plt.cm.rainbow(np.linspace(0, 1, 10))

# for img_idx, (image, image_layers, clusters) in enumerate(zip(images, hdb_all_image_layers, hdb_cluster_list)):
#     fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    
#     axes[0].imshow(image)
#     axes[0].set_title(f'Image {img_idx + 1}: Original')
#     axes[0].axis('off')
    
#     axes[1].imshow(image)
#     region_colors = plt.cm.tab10(np.linspace(0, 1, max(len(clusters), 1)))
    
#     for region_idx, (region_id, points) in enumerate(clusters.items()):
#         points_array = np.array(points)
#         axes[1].scatter(points_array[:, 1], points_array[:, 0], 
#                        c=[region_colors[region_idx % len(region_colors)]], 
#                        s=10, alpha=0.5, label=f'Region {region_id}')
    
#     axes[1].set_title(f'Nuclei Clusters ({len(clusters)} regions)')
#     axes[1].axis('off')
#     if len(clusters) <= 10:
#         axes[1].legend(loc='upper right', fontsize=8)
    
#     axes[2].imshow(image)
    
#     for region_id, layers in image_layers.items():
#         for layer_id, layer_points in layers.items():
#             if len(layer_points) > 0:
#                 layer_array = np.array(layer_points)
#                 axes[2].plot(layer_array[:, 1], layer_array[:, 0], 
#                            color=layer_colors[layer_id % len(layer_colors)], 
#                            linewidth=2.5, alpha=0.8,
#                            label=f'Layer {layer_id}' if region_id == list(image_layers.keys())[0] else '')
    
#     axes[2].set_title(f'Detected Layers (colored by depth)')
#     axes[2].axis('off')
    

#     handles, labels = axes[2].get_legend_handles_labels()
#     by_label = dict(zip(labels, handles))
#     axes[2].legend(by_label.values(), by_label.keys(), loc='upper right', fontsize=8)
    
#     plt.tight_layout()
#     plt.show()

### Iterative layer approach

In [47]:
test_data_dict = nuclei_data_dicts[1]

In [48]:
test_data_dict.keys()

dict_keys(['coord', 'points', 'prob'])

In [49]:
test_data_dict["coord"][0][0]

array([1216.    , 1218.2001, 1220.3412, 1222.0779, 1223.7898, 1225.5178,
       1226.9631, 1227.9491, 1228.8011, 1228.3591, 1227.829 , 1226.3602,
       1224.6915, 1222.7585, 1220.7545, 1218.3894, 1216.    , 1213.5497,
       1211.0656, 1209.0938, 1207.5558, 1206.4421, 1205.4302, 1205.199 ,
       1204.587 , 1204.9575, 1204.9989, 1205.871 , 1206.9973, 1208.9211,
       1211.119 , 1213.6681], dtype=float32)

In [50]:
test_data_dict["coord"][1][1]

array([2129.422 , 2129.515 , 2129.739 , 2128.9075, 2127.6082, 2125.6465,
       2123.2512, 2120.5698, 2118.    , 2115.4695, 2112.9644, 2110.8137,
       2109.136 , 2107.5898, 2106.5623, 2106.0505, 2105.2817, 2105.4067,
       2105.3987, 2106.194 , 2107.4004, 2109.3662, 2111.903 , 2115.    ,
       2118.    , 2120.8035, 2123.3523, 2125.2053, 2126.5366, 2127.4512,
       2128.4377, 2128.7634], dtype=float32)

In [51]:
len(test_data_dict["coord"][1][0])

32

In [52]:
test_points = np.array(test_data_dict["points"])
test_mask = masks[1]

filtered_points = [
    p for p in test_points
    if test_mask[int(p[0]), int(p[1])] != 0
]


print("Before:", len(test_points), "After:", len(filtered_points))

Before: 403 After: 240


In [53]:
test_points.shape

(403, 2)

In [54]:
np.array(filtered_points).shape

(240, 2)

In [55]:
test_points = np.array(filtered_points)

In [56]:
test_tri = Delaunay(test_points)

In [81]:
# plt.figure(figsize=(12, 12))
# plt.imshow(images[1])
# plt.axis("off")

In [82]:
# plt.figure(figsize=(12, 12))
# plt.imshow(images[1])
# plt.triplot(test_points[:, 1], test_points[:, 0], test_tri.simplices, 
#             color='red', linewidth=0.5, alpha=0.7)
# plt.plot(test_points[:, 1], test_points[:, 0], 'o', 
#          color='blue', markersize=3, alpha=0.8)
# plt.axis('off')
# plt.tight_layout()
# plt.show()

In [59]:
def get_delaunay_neighbors(points):
    """Build a neighbor dictionary from Delaunay triangulation."""
    points = np.array(points)
    tri = Delaunay(points)
    
    neighbors = defaultdict(set)
    
    for simplex in tri.simplices:
        for i in range(3):
            for j in range(3):
                if i != j:
                    neighbors[simplex[i]].add(simplex[j])
    
    return dict(neighbors), points

In [60]:
neighbors, points_array = get_delaunay_neighbors(test_points)

In [61]:
neighbor_indices = neighbors[0]
point_0_coords = points_array[0]
neighbor_coords = points_array[list(neighbor_indices)]

In [62]:
import networkx as nx

def build_neighbor_graph(points):
    """Build a NetworkX graph from Delaunay triangulation."""
    points = np.array(points)
    tri = Delaunay(points)
    
    G = nx.Graph()
    
    for i, (x, y) in enumerate(points):
        G.add_node(i, pos=(x, y))
    
    for simplex in tri.simplices:
        for i in range(3):
            for j in range(i+1, 3):
                p1, p2 = points[simplex[i]], points[simplex[j]]
                dist = np.linalg.norm(p1 - p2)
                G.add_edge(simplex[i], simplex[j], weight=dist)
    
    return G

G = build_neighbor_graph(test_points)

neighbors_of_5 = list(G.neighbors(5))
print("Neighbors of 5: ", neighbors_of_5)

pos = G.nodes[5]['pos']
print("Pos of node 5: ", pos)

for node in G.nodes():
    for neighbor in G.neighbors(node):
        edge_length = G[node][neighbor]['weight']

Neighbors of 5:  [18, 6, 42, 217, 113, 26, 165]
Pos of node 5:  (1542, 872)


In [63]:
def update_nuclei_data_dict(stardist_data_dict):
    """
    stardist_data_dict: whatever object you have per image; must contain:
      - 'points': nuclei centroids in same order as coord entries
      - 'coord': list/array of length N, each entry is ([x1..x32], [y1..y32])
    Returns: dict idx -> {'boundary': np.ndarray shape (32,2)}
    """
    nuclei_dict = {}
    coords_all = stardist_data_dict['coord']
    for idx, coord_pair in enumerate(coords_all):
        xs, ys = coord_pair
        boundary = np.stack([xs, ys], axis=1)
        nuclei_dict[idx] = {'boundary': boundary}
    return nuclei_dict

In [64]:
from sklearn.decomposition import PCA
from scipy.spatial.distance import cosine

def get_nucleus_orientation(boundary_points):
    """
    Extract main axis (orientation) of a nucleus from boundary points.
    Returns: unit vector representing the main axis direction
    """
    boundary_points = np.array(boundary_points)
    
    # Center the points
    centroid = boundary_points.mean(axis=0)
    centered = boundary_points - centroid
    
    # PCA to find main axis
    pca = PCA(n_components=2)
    pca.fit(centered)
    
    # First principal component is the main axis
    main_axis = pca.components_[0]
    
    return main_axis

def calculate_alignment_with_orientations(neighbors_dict, points_array, nuclei_data_dict):
    """
    Calculate alignment score based on cosine similarity of nucleus orientations.
    
    Returns: dict mapping point_index -> alignment_score
    """
    alignments = {}
    
    for point_idx in neighbors_dict:
        neighbor_indices = list(neighbors_dict[point_idx])
        
        if len(neighbor_indices) < 2:
            alignments[point_idx] = 0.0
            continue
        
        if point_idx not in nuclei_data_dict or 'boundary' not in nuclei_data_dict[point_idx]:
            alignments[point_idx] = 0.0
            continue
            
        current_orientation = get_nucleus_orientation(nuclei_data_dict[point_idx]['boundary'])
        
        similarities = []
        for neighbor_idx in neighbor_indices:
            if neighbor_idx in nuclei_data_dict and 'boundary' in nuclei_data_dict[neighbor_idx]:
                neighbor_orientation = get_nucleus_orientation(nuclei_data_dict[neighbor_idx]['boundary'])
                
                # Cosine similarity (1 - cosine distance)
                # High similarity -> good alignment
                similarity = abs(1 - cosine(current_orientation, neighbor_orientation))
                similarities.append(similarity)
        
        if similarities:
            alignments[point_idx] = np.mean(similarities)
        else:
            alignments[point_idx] = 0.0
    
    return alignments

neighbors_dict, points_array = get_delaunay_neighbors(test_points)

nuclei_data_dict = update_nuclei_data_dict(test_data_dict)

alignments = calculate_alignment_with_orientations(neighbors_dict, points_array, nuclei_data_dict)

best_start = max(alignments, key=alignments.get)
print(f"Best starting point: {best_start} with alignment score: {alignments[best_start]:.3f}")

Best starting point: 137 with alignment score: 0.977


In [65]:
best_starting_coord = points_array[best_start]
best_starting_coord

array([1328, 1170])

In [83]:
# plt.figure(figsize=(12, 12))
# plt.imshow(images[1])
# plt.plot(best_starting_coord[1], best_starting_coord[0], 'o', 
#          color='blue', markersize=5, alpha=1.0)
# plt.axis('off')
# plt.tight_layout()
# plt.show()

In [67]:
def prepare_graph(points, neighbors_dict, nuclei_data_dict):
    orientations = {}
    for i, data in nuclei_data_dict.items():
        if 'boundary' in data:
            orientations[i] = get_nucleus_orientation(data['boundary'])

    edge_sim = {}
    edge_dist = {}
    for u, nbrs in neighbors_dict.items():
        for v in nbrs:
            if u < v:
                d = np.linalg.norm(points[u] - points[v])
                edge_dist[(u, v)] = d
                if u in orientations and v in orientations:
                    s = abs(1 - cosine(orientations[u], orientations[v]))
                else:
                    s = -1
                edge_sim[(u, v)] = s
    return orientations, edge_dist, edge_sim

In [68]:
def grow_layer(start, neighbors_dict, edge_dist, edge_sim,
               sim_thresh=0.6, dist_thresh=50.0):
    visited = set([start])
    frontier = [start]
    layer = [start]

    while frontier:
        u = frontier.pop()

        candidates = []
        for v in neighbors_dict[u]:
            key = (u, v) if (u, v) in edge_sim else (v, u)
            if key not in edge_sim or key not in edge_dist:
                continue
            s = edge_sim[key]
            d = edge_dist[key]
            if s >= sim_thresh and d <= dist_thresh:
                candidates.append((s, v))

        # Keep only the top-2 by similarity
        candidates.sort(reverse=True, key=lambda x: x[0])
        for _, v in candidates[:2]:
            if v not in visited:
                visited.add(v)
                frontier.append(v)
                layer.append(v)

    return layer, visited

In [69]:
nuclei_to_explore = dict(alignments)
layers = []
orientations, edge_dist, edge_sim = prepare_graph(test_points, neighbors_dict, nuclei_data_dict)

while nuclei_to_explore:
    best_start = max(nuclei_to_explore, key=nuclei_to_explore.get)
    layer, visited = grow_layer(best_start, neighbors_dict, edge_dist, edge_sim)
    
    if len(layer) > 1:
        layers.append(layer)
    
    for idx in visited:
        nuclei_to_explore.pop(idx, None)

In [84]:
# import matplotlib.pyplot as plt
# import numpy as np

# def plot_layers(image, points, layers, point_order="xy"):
#     """
#     image: 2D or 3D array
#     points: array-like of shape (N, 2)
#     layers: list of lists of point indices
#     point_order: "xy" if points are (x, y); "yx" if points are (row, col)
#     """
#     points = np.array(points)
#     plt.figure(figsize=(10, 10))
#     h, w = image.shape[:2]
#     plt.imshow(image, cmap='gray', extent=[0, w, h, 0])
    
#     colors = plt.cm.tab20(np.linspace(0, 1, max(1, len(layers))))
    
#     for i, layer in enumerate(layers):
#         coords = points[layer]
#         if point_order == "xy":
#             xs, ys = coords[:, 0], coords[:, 1]
#         else:
#             ys, xs = coords[:, 0], coords[:, 1]
        
#         plt.scatter(xs, ys, s=20, color=colors[i], alpha=0.9, label=f"Layer {i+1}")
        
#         # Connect points in order of layer list (greedy path)
#         plt.plot(xs, ys, color=colors[i], alpha=0.6, linewidth=1)

#     plt.axis('off')
#     plt.tight_layout()
#     plt.show()

# plot_layers(images[1], test_points, layers, point_order="yx")