# CPM (Cell Phenotype Map) Montage Creator

This notebook creates montages of `cell_mask_colored.tiff` files from a Pixie output directory.

**Features:**
- Montage grid of colored cell masks
- Filename labels below each image
- Color legend showing cell types
- Customizable grid layout and styling

## 1. Setup and Imports

In [None]:
import os
from pathlib import Path
from typing import List, Optional, Tuple, Dict
import colorsys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from skimage import io as skio
from tqdm.auto import tqdm

%matplotlib inline
plt.rcParams['figure.dpi'] = 100

## 2. Configuration

Set your input directory and parameters below.

In [None]:
# =============================================================================
# USER CONFIGURATION - edit the values in this cell before running the notebook
# =============================================================================

# --- Input -------------------------------------------------------------------
# Pixie cell output directory; expected to contain a cell_mask_colored/ folder.
PIXIE_OUTPUT_DIR = "path/to/pixie/cell_output_dir"

# Cluster mapping CSV. None means: look for
# <PIXIE_OUTPUT_DIR>/cell_meta_cluster_mapping.csv instead.
CLUSTER_MAPPING_FILE = None  # e.g., "path/to/cell_meta_cluster_mapping.csv"

# Restrict the montage to these FOVs; None keeps every FOV that was found.
SELECTED_FOVS = None         # e.g., ["R0C0", "R0C1", "R1C0"]

# --- Grid layout -------------------------------------------------------------
NCOLS = 4                    # columns in the montage grid
IMAGE_SIZE = 4.0             # edge length of each tile, in inches
TITLE = "Cell Phenotype Map Montage"  # None disables the title

# --- Styling -----------------------------------------------------------------
SHOW_LEGEND = True           # draw the cell-type color legend
FONT_SIZE = 10               # font size of the FOV labels
TITLE_FONT_SIZE = 14         # font size of the title

# --- Output ------------------------------------------------------------------
OUTPUT_PATH = "CPM_montage.png"  # None = display only, do not write a file
DPI = 150

## 3. Helper Functions

In [None]:
def distinct_rgbs(n: int = 33) -> List[Tuple[float, float, float]]:
    """
    Return ``n`` deterministic RGB colors spread evenly around the hue wheel.

    Hue advances by ``1 / n`` per color. Saturation cycles through three
    levels and value through two, so neighbouring hues also differ in
    intensity. Components are floats in [0, 1] from ``colorsys.hsv_to_rgb``.
    """
    def _shade(idx: int) -> Tuple[float, float, float]:
        sat_step = (idx % 3) / 2   # 0, 0.5, 1 -> saturation ~0.7 / 0.85 / 1.0
        val_step = (idx % 2) / 1   # 0, 1      -> value ~0.8 / 1.0
        return colorsys.hsv_to_rgb(
            idx / n, 0.7 + 0.3 * sat_step, 0.8 + 0.2 * val_step
        )

    return [_shade(idx) for idx in range(n)]


def discover_mask_files(colored_masks_dir: Path) -> List[Path]:
    """
    Discover all colored mask TIFF files in the directory.

    Name patterns are tried from most to least specific: ``*_cell_mask_colored``,
    then ``*cell_mask_colored``, then any TIFF. The first pattern that matches
    anything wins. Within a pattern, both ``.tiff`` and ``.tif`` files are
    collected, so a folder mixing the two extensions is returned complete.

    Args:
        colored_masks_dir: directory to search (non-recursive).

    Returns:
        Sorted list of matching regular files; ``[]`` if nothing matches or the
        directory does not exist.
    """
    search_dir = Path(colored_masks_dir)
    for stem_pattern in ("*_cell_mask_colored", "*cell_mask_colored", "*"):
        # Set: guards against the same file being yielded by two globs.
        found = {
            path
            for extension in (".tiff", ".tif")
            for path in search_dir.glob(stem_pattern + extension)
            if path.is_file()
        }
        if found:
            return sorted(found)
    return []


def extract_colors_from_image(image: np.ndarray) -> Dict[Tuple[int, ...], int]:
    """
    Count the pixels of each distinct color in an RGB(A) image.

    Args:
        image: array whose last axis holds the color channels (``ndim == 3``).

    Returns:
        Mapping from color tuple (one entry per channel, alpha included when
        present) to pixel count. Keys and counts are plain Python numbers.
        Callers can therefore do arithmetic on them, such as summing uint8
        channels, without NumPy fixed-width overflow. Input that is not 3-D
        yields ``{}``.
    """
    if image.ndim != 3:
        return {}

    pixels = image.reshape(-1, image.shape[2])
    unique_colors, counts = np.unique(pixels, axis=0, return_counts=True)
    # .tolist() converts NumPy scalars into native Python int/float values.
    return {
        tuple(color): count
        for color, count in zip(unique_colors.tolist(), counts.tolist())
    }


def build_color_legend(
    images: List[np.ndarray],
    cluster_mapping: Optional[pd.DataFrame] = None,
) -> List[Tuple[Tuple[float, ...], str]]:
    """
    Build ``(rgb, label)`` legend entries from a cluster mapping or the images.

    With a cluster mapping, the legend has one entry per unique non-null value.
    Values come from the first name column present (``cell_meta_cluster_rename``,
    ``meta_cluster_rename``, ``cluster_rename``, ``cell_type``, ``phenotype``).
    Failing that, they come from the first cluster-ID column present
    (``cell_meta_cluster``, ``meta_cluster``, ``cluster``). Colors are
    assigned by ``distinct_rgbs``.
    NOTE(review): these colors are generated here, not read from the masks.
    Confirm they match the palette used to render the ``cell_mask_colored``
    images.

    Without a usable mapping, the legend lists up to 20 of the most frequent
    non-background (non near-black) colors found in ``images``, labeled
    "Type 1", "Type 2", ...

    Args:
        images: RGB(A) images; only 3-D arrays contribute colors.
        cluster_mapping: optional cluster mapping table.

    Returns:
        List of ``(rgb, label)`` pairs with RGB components in [0, 1].
    """
    legend_items: List[Tuple[Tuple[float, ...], str]] = []

    if cluster_mapping is not None:
        legend_items = _legend_from_mapping(cluster_mapping)

    # Fall back to colors actually present in the images.
    if not legend_items and images:
        legend_items = _legend_from_images(images)

    return legend_items


def _legend_from_mapping(
    cluster_mapping: pd.DataFrame,
) -> List[Tuple[Tuple[float, ...], str]]:
    """Legend entries from a mapping table; [] if no known column exists."""
    columns = cluster_mapping.columns

    rename_col = next(
        (c for c in ('cell_meta_cluster_rename', 'meta_cluster_rename',
                     'cluster_rename', 'cell_type', 'phenotype')
         if c in columns),
        None,
    )
    if rename_col:
        names = sorted(
            c for c in cluster_mapping[rename_col].unique() if pd.notna(c)
        )
        colors = distinct_rgbs(len(names))
        return [(colors[i], str(name)) for i, name in enumerate(names)]

    cluster_col = next(
        (c for c in ('cell_meta_cluster', 'meta_cluster', 'cluster')
         if c in columns),
        None,
    )
    if not cluster_col:
        return []

    cluster_ids = sorted(
        c for c in cluster_mapping[cluster_col].unique() if pd.notna(c)
    )
    colors = distinct_rgbs(len(cluster_ids))
    return [
        (colors[i], f"Cluster {_format_cluster_id(cid)}")
        for i, cid in enumerate(cluster_ids)
    ]


def _format_cluster_id(cluster_id) -> str:
    """Render a cluster ID, showing integer-valued floats (e.g. 3.0) as '3'."""
    # pandas stores an integer column that contains NaN as float.
    if isinstance(cluster_id, (float, np.floating)) and float(cluster_id).is_integer():
        return str(int(cluster_id))
    return str(cluster_id)


def _channel_scale(img: np.ndarray) -> float:
    """Full-scale channel value: 255 for 8-bit-range data, 1.0 for [0, 1] floats."""
    if img.size == 0:
        return 255.0
    peak = float(img[..., :3].max())
    if np.issubdtype(img.dtype, np.integer):
        return 255.0 if peak <= 255 else float(np.iinfo(img.dtype).max)
    return 255.0 if peak > 1.0 else 1.0


def _legend_from_images(
    images: List[np.ndarray],
    max_items: int = 20,
) -> List[Tuple[Tuple[float, ...], str]]:
    """Legend entries for the most frequent non-background colors in images."""
    totals: Dict[Tuple[float, ...], int] = {}
    for img in images:
        if img.ndim != 3:
            continue  # extract_colors_from_image returns {} for these anyway
        scale = _channel_scale(img)
        # Background = near-black: RGB sum <= 10 on a 0-255 scale, rescaled
        # so float images in [0, 1] are not discarded wholesale.
        threshold = 10.0 * scale / 255.0
        for color, count in extract_colors_from_image(img).items():
            # Python floats: summing NumPy uint8 scalars would wrap around.
            channels = [float(v) for v in color[:3]]
            if sum(channels) <= threshold:
                continue
            # Key on RGB only so colors differing just in alpha merge.
            rgb = tuple(min(1.0, v / scale) for v in channels)
            totals[rgb] = totals.get(rgb, 0) + int(count)

    ranked = sorted(totals.items(), key=lambda item: -item[1])[:max_items]
    return [(rgb, f"Type {i}") for i, (rgb, _) in enumerate(ranked, start=1)]

## 4. Locate and Load Data

In [None]:
# Resolve input locations: the colored masks, then the optional cluster mapping.
pixie_output_dir = Path(PIXIE_OUTPUT_DIR)

# Masks normally live in <output>/cell_mask_colored; otherwise search the
# output directory itself.
colored_masks_dir = pixie_output_dir / "cell_mask_colored"
if not colored_masks_dir.exists():
    colored_masks_dir = pixie_output_dir

if not colored_masks_dir.exists():
    raise FileNotFoundError(f"Could not find directory: {colored_masks_dir}")

print(f"Colored masks directory: {colored_masks_dir}")

# An explicitly configured mapping file must exist (a typo would otherwise
# silently produce a generic legend). The default location is only a
# best-effort guess and may legitimately be absent.
if CLUSTER_MAPPING_FILE:
    cluster_mapping_path = Path(CLUSTER_MAPPING_FILE)
    if not cluster_mapping_path.is_file():
        raise FileNotFoundError(
            f"CLUSTER_MAPPING_FILE does not exist: {cluster_mapping_path}"
        )
else:
    cluster_mapping_path = pixie_output_dir / "cell_meta_cluster_mapping.csv"

# Load cluster mapping if available
if cluster_mapping_path.is_file():
    cluster_mapping = pd.read_csv(cluster_mapping_path)
    print(f"Loaded cluster mapping from: {cluster_mapping_path}")
    print(f"  Columns: {list(cluster_mapping.columns)}")
else:
    cluster_mapping = None
    print(f"No cluster mapping file found at: {cluster_mapping_path}")

In [None]:
# Find the colored mask files and list (at most) the first ten of them.
mask_files = discover_mask_files(colored_masks_dir)

shown_limit = 10
print(f"\nFound {len(mask_files)} colored mask files:")
for mask_path in mask_files[:shown_limit]:
    print(f"  - {mask_path.name}")
remaining = len(mask_files) - shown_limit
if remaining > 0:
    print(f"  ... and {remaining} more")

In [None]:
# Keep only the FOVs listed in SELECTED_FOVS, in the order they are listed.
# A FOV name is the file stem minus the mask suffix (the same rule used for
# labels when loading images). Exact matches take priority: a plain substring
# test lets "R0C1" pick "R0C10_cell_mask_colored", which sorts first.
if SELECTED_FOVS:
    def _fov_name(path: Path) -> str:
        return path.stem.replace("_cell_mask_colored", "").replace("_cell_mask", "")

    filtered_files = []
    missing_fovs = []
    for fov in SELECTED_FOVS:
        exact = [mf for mf in mask_files if _fov_name(mf) == fov]
        # Fall back to substring matching so partial names keep working.
        candidates = exact or [mf for mf in mask_files if fov in mf.stem]
        if candidates:
            filtered_files.append(candidates[0])
        else:
            missing_fovs.append(fov)
    mask_files = filtered_files
    print(f"Filtered to {len(mask_files)} FOVs: {SELECTED_FOVS}")
    if missing_fovs:
        print(f"  Warning: no mask file found for: {missing_fovs}")

In [None]:
# Show the first rows of the cluster mapping when one was loaded.
if cluster_mapping is not None:
    mapping_head = cluster_mapping.head(10)
    display(mapping_head)

## 5. Load Images

In [None]:
# Read every mask into memory and derive one display label per file.
print("Loading images...")
images = []
labels = []

for mask_path in tqdm(mask_files, desc="Loading masks"):
    images.append(skio.imread(mask_path))
    # FOV label = file stem with the mask suffix stripped
    stem = mask_path.stem
    labels.append(stem.replace("_cell_mask_colored", "").replace("_cell_mask", ""))

print(f"\nLoaded {len(images)} images")
if images:
    first_image = images[0]
    print(f"Image shape: {first_image.shape}")
    print(f"Image dtype: {first_image.dtype}")

In [None]:
# Sanity check: render the first mask on its own before building the montage.
if images:
    first_image, first_label = images[0], labels[0]
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(first_image)
    ax.set_title(f"Preview: {first_label}")
    ax.set_axis_off()
    plt.show()

## 6. Build Color Legend

In [None]:
# Assemble legend entries; skipped entirely when SHOW_LEGEND is False.
legend_items = []
if SHOW_LEGEND:
    legend_items = build_color_legend(images, cluster_mapping)

print(f"Legend has {len(legend_items)} items:")
for _, item_label in legend_items:
    print(f"  - {item_label}")

In [None]:
# Draw the legend by itself so its colors and labels can be checked.
if legend_items:
    n_items = len(legend_items)
    fig, ax = plt.subplots(figsize=(6, max(2, n_items * 0.3)))

    legend_patches = [
        mpatches.Patch(facecolor=rgb, edgecolor='black', linewidth=0.5, label=name)
        for rgb, name in legend_items
    ]

    ax.legend(
        handles=legend_patches,
        loc='center',
        title='Cell Types',
        fontsize=9,
        ncol=1 if n_items <= 10 else 2,  # split long legends into two columns
    )
    ax.axis('off')
    ax.set_title("Legend Preview")
    plt.tight_layout()
    plt.show()

## 7. Create Montage

In [None]:
def create_montage(
    images: List[np.ndarray],
    labels: List[str],
    legend_items: List[Tuple[Tuple[float, ...], str]],
    ncols: int = 4,
    figsize_per_image: Tuple[float, float] = (4, 4),
    label_height: float = 0.4,
    legend_width: float = 2.5,
    dpi: int = 150,
    title: Optional[str] = None,
    show_legend: bool = True,
    font_size: int = 10,
    title_font_size: int = 14,
) -> plt.Figure:
    """
    Create a montage of colored cell masks with labels and legend.

    Images are placed row-major in a grid ``ncols`` wide. Each image gets its
    label in a white box just below it. When ``show_legend`` is true and
    ``legend_items`` is non-empty, a legend column ``legend_width`` inches wide
    is added to the right of the grid.

    Args:
        images: images passed to ``imshow``, one per grid cell.
        labels: one label per image, in the same order as ``images``.
        legend_items: ``(color, label)`` pairs drawn as legend patches.
        ncols: number of grid columns (must be >= 1).
        figsize_per_image: ``(width, height)`` of one grid cell, in inches.
        label_height: extra height per row reserved for labels, in inches.
        legend_width: width of the legend column, in inches.
        dpi: figure resolution.
        title: optional figure title.
        show_legend: whether to draw the legend.
        font_size: font size of the image labels (legend sizes derive from it).
        title_font_size: font size of the title.

    Returns:
        The created figure; it is neither shown nor closed here.

    Raises:
        ValueError: if ``images`` is empty, ``ncols < 1``, or ``labels`` does
            not contain exactly one entry per image.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("create_montage() needs at least one image")
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")
    if len(labels) != n_images:
        # zip() below would otherwise silently drop unlabeled images.
        raise ValueError(
            f"Got {n_images} images but {len(labels)} labels; counts must match"
        )
    nrows = int(np.ceil(n_images / ncols))

    # Figure size: the image grid, plus the legend column and the title band.
    img_width, img_height = figsize_per_image
    total_img_height = img_height + label_height

    has_legend = bool(show_legend and legend_items)
    fig_width = ncols * img_width + (legend_width if has_legend else 0)
    fig_height = nrows * total_img_height + (0.5 if title else 0)

    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)

    grid_top = 0.95 if title else 0.98  # leave room for the suptitle
    if has_legend:
        # Fraction of the figure width occupied by the image grid.
        main_width_ratio = ncols * img_width / fig_width

        gs_main = GridSpec(
            nrows, ncols,
            left=0.02, right=main_width_ratio - 0.02,
            top=grid_top, bottom=0.02,
            wspace=0.05, hspace=0.15
        )
        gs_legend = GridSpec(
            1, 1,
            left=main_width_ratio + 0.02, right=0.98,
            top=0.85, bottom=0.15
        )
    else:
        gs_main = GridSpec(
            nrows, ncols,
            left=0.02, right=0.98,
            top=grid_top, bottom=0.02,
            wspace=0.05, hspace=0.15
        )
        gs_legend = None

    for idx, (img, label) in enumerate(zip(images, labels)):
        row, col = divmod(idx, ncols)
        ax = fig.add_subplot(gs_main[row, col])
        ax.imshow(img)
        ax.axis('off')

        # Label centered just below the image (axes coordinates, y < 0).
        ax.text(
            0.5, -0.02,
            label,
            transform=ax.transAxes,
            ha='center', va='top',
            fontsize=font_size,
            fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                      edgecolor='gray', alpha=0.9)
        )

    if gs_legend is not None:
        ax_legend = fig.add_subplot(gs_legend[0, 0])
        ax_legend.axis('off')

        legend_patches = [
            mpatches.Patch(facecolor=color, edgecolor='black',
                           linewidth=0.5, label=lbl)
            for color, lbl in legend_items
        ]

        legend = ax_legend.legend(
            handles=legend_patches,
            loc='center left',
            title='Cell Types',
            title_fontsize=font_size + 1,
            fontsize=font_size - 1,
            frameon=True,
            fancybox=True,
            shadow=False,
            ncol=1 if len(legend_items) <= 15 else 2,
        )
        legend.get_frame().set_edgecolor('gray')
        legend.get_frame().set_linewidth(1)

    if title:
        fig.suptitle(title, fontsize=title_font_size, fontweight='bold', y=0.98)

    # No tight_layout(): both GridSpecs set their margins explicitly, so
    # tight_layout would skip these Axes and only emit a compatibility warning.
    return fig

In [None]:
# Render the montage using the settings from the configuration cell.
print("Creating montage...")

montage_settings = dict(
    ncols=NCOLS,
    figsize_per_image=(IMAGE_SIZE, IMAGE_SIZE),
    dpi=DPI,
    title=TITLE,
    show_legend=SHOW_LEGEND,
    font_size=FONT_SIZE,
    title_font_size=TITLE_FONT_SIZE,
)
fig = create_montage(images, labels, legend_items, **montage_settings)

plt.show()

## 8. Save Montage

In [None]:
# Save the montage to OUTPUT_PATH (skipped when OUTPUT_PATH is None/empty).
if OUTPUT_PATH:
    output_file = Path(OUTPUT_PATH)
    # savefig does not create missing folders (e.g. "results/CPM_montage.png").
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Recreate figure for saving (in case display altered it)
    fig = create_montage(
        images=images,
        labels=labels,
        legend_items=legend_items,
        ncols=NCOLS,
        figsize_per_image=(IMAGE_SIZE, IMAGE_SIZE),
        dpi=DPI,
        title=TITLE,
        show_legend=SHOW_LEGEND,
        font_size=FONT_SIZE,
        title_font_size=TITLE_FONT_SIZE,
    )

    try:
        fig.savefig(
            output_file,
            dpi=DPI,
            bbox_inches='tight',
            facecolor='white',
            edgecolor='none'
        )
    finally:
        plt.close(fig)  # release the figure even if saving fails

    print(f"Saved montage to: {OUTPUT_PATH}")
else:
    print("OUTPUT_PATH not set - montage not saved")

---
## 9. Custom Legend (Optional)

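If the auto-detected legend doesn't match your data, you can specify a custom color-to-label mapping. Check this especially when a cluster mapping CSV was loaded. In that case the legend colors are generated by `distinct_rgbs` rather than read from the mask images, so they may not match the colors shown in the montage.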

In [None]:
# Optional hand-written legend: map each cell-type label to an RGB tuple with
# components in [0, 1]. Uncomment the example below and adapt it to your data.
#
# CUSTOM_COLOR_MAPPING = {
#     "T cells": (0.8, 0.2, 0.2),       # Red
#     "B cells": (0.2, 0.2, 0.8),       # Blue
#     "Macrophages": (0.2, 0.8, 0.2),   # Green
#     "Tumor cells": (0.8, 0.8, 0.2),   # Yellow
#     "Fibroblasts": (0.8, 0.2, 0.8),   # Magenta
#     "Endothelial": (0.2, 0.8, 0.8),   # Cyan
# }

# None keeps the auto-detected legend; assign a dict like the one above instead.
CUSTOM_COLOR_MAPPING: Optional[Dict[str, Tuple[float, float, float]]] = None

In [None]:
# Re-render (and optionally save) the montage with the hand-written legend.
if CUSTOM_COLOR_MAPPING:
    custom_legend_items = [(color, label) for label, color in CUSTOM_COLOR_MAPPING.items()]

    fig = create_montage(
        images=images,
        labels=labels,
        legend_items=custom_legend_items,
        ncols=NCOLS,
        figsize_per_image=(IMAGE_SIZE, IMAGE_SIZE),
        dpi=DPI,
        title=TITLE,
        show_legend=True,
        font_size=FONT_SIZE,
        title_font_size=TITLE_FONT_SIZE,
    )

    plt.show()

    # Save with custom legend
    if OUTPUT_PATH:
        # Insert the suffix before the extension. str.replace(".png", ...) left
        # non-PNG paths unchanged, so the main montage got overwritten.
        base_path = Path(OUTPUT_PATH)
        base_path.parent.mkdir(parents=True, exist_ok=True)
        custom_output = str(
            base_path.with_name(f"{base_path.stem}_custom_legend{base_path.suffix}")
        )
        fig.savefig(custom_output, dpi=DPI, bbox_inches='tight', facecolor='white')
        print(f"Saved: {custom_output}")
    plt.close(fig)
else:
    print("No custom color mapping defined - skipping")

---
## 10. Create Subset Montages (Optional)

Create montages for specific subsets of FOVs.

In [None]:
# Print every loaded FOV label with a 1-based, right-aligned index.
print("Available FOVs:")
for number, fov_label in enumerate(labels, start=1):
    print(f"  {number:3d}. {fov_label}")

In [None]:
# Example: Create a montage with only the first 6 FOVs
# Uncomment to run

# subset_indices = [0, 1, 2, 3, 4, 5]  # First 6 FOVs
# subset_images = [images[i] for i in subset_indices]
# subset_labels = [labels[i] for i in subset_indices]

# fig = create_montage(
#     images=subset_images,
#     labels=subset_labels,
#     legend_items=legend_items,
#     ncols=3,
#     figsize_per_image=(5, 5),
#     dpi=DPI,
#     title="Selected FOVs",
#     show_legend=SHOW_LEGEND,
# )
# plt.show()

---
## 11. Summary

In [None]:
print("="*50)
print("CPM MONTAGE SUMMARY")
print("="*50)
print(f"\nInput directory: {PIXIE_OUTPUT_DIR}")
print(f"Images loaded: {len(images)}")
print(f"Legend items: {len(legend_items)}")
print(f"Grid layout: {NCOLS} columns Ã— {int(np.ceil(len(images)/NCOLS))} rows")
if OUTPUT_PATH:
    print(f"\nOutput saved to: {OUTPUT_PATH}")
print("\nDone!")