In [2]:
# ALWAYS RUN THIS FIRST!
import os
import sys
from pathlib import Path

NOTEBOOK_DIR = Path("/rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest")
os.chdir(NOTEBOOK_DIR)
sys.path.insert(0, str(NOTEBOOK_DIR))

print(f"‚úÖ Working directory: {os.getcwd()}")

‚úÖ Working directory: /rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest


In [9]:
import os
import glob
import numpy as np
import zarr
import pandas as pd
import cv2
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
import numcodecs
from scipy import ndimage

# --- CONFIGURATION ---
BASE_PATH = Path("/rsrch9/home/plm/idso_fa1_pathology/TIER1/yasin-vitaminp/public-datasets/TNBC")
OUTPUT_BASE = Path("/rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/tnbc/zarr_data")

# Split configuration
TRAIN_SLIDES = [1, 2, 3, 4, 5, 6, 7]  # 7 patients for training
VAL_SLIDES = [8, 9]                    # 2 patients for validation
TEST_SLIDES = [10, 11]                 # 2 patients for testing

PATCH_SIZE = 512
NUM_WORKERS = 16
cv2.setNumThreads(0)

# ---------------------------------------------------------------------
# 1. Convert Binary Mask to Instance Mask
# ---------------------------------------------------------------------
def binary_to_instance_mask(binary_mask):
    """
    Convert binary mask (0=background, 255=nuclei) to instance mask.
    Uses connected components to assign unique ID to each nucleus.
    """
    # Ensure binary format
    binary = (binary_mask > 127).astype(np.uint8)
    
    # Label connected components (each nucleus gets unique ID)
    instance_mask, num_instances = ndimage.label(binary)
    
    return instance_mask.astype(np.int32)

# ---------------------------------------------------------------------
# 2. Padding Logic
# ---------------------------------------------------------------------
def pad_to_512_multiple(img):
    """Pad image/mask to be evenly divisible by 512"""
    h, w = img.shape[:2]
    target_h = max(int(np.ceil(h / PATCH_SIZE) * PATCH_SIZE), PATCH_SIZE)
    target_w = max(int(np.ceil(w / PATCH_SIZE) * PATCH_SIZE), PATCH_SIZE)
    
    pad_h, pad_w = target_h - h, target_w - w
    if pad_h == 0 and pad_w == 0:
        return img
    
    # Handle both 2D (mask) and 3D (image) arrays
    if img.ndim == 3:
        padding = ((0, pad_h), (0, pad_w), (0, 0))
    else:
        padding = ((0, pad_h), (0, pad_w))
    
    return np.pad(img, padding, mode='constant', constant_values=0)

# ---------------------------------------------------------------------
# 3. Worker Function - Process Single Image Pair
# ---------------------------------------------------------------------
def process_image_pair(args):
    """
    Process one image+mask pair:
    1. Load image and GT mask
    2. Convert binary mask to instance mask
    3. Pad to 512 multiples
    4. Extract 512x512 patches
    5. Return patches (will be aggregated per slide later)
    """
    img_path, gt_path, patch_name = args
    
    try:
        # --- A. Load Image ---
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"‚ö†Ô∏è  Cannot read image: {img_path}")
            return None
        
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h_img, w_img = img.shape[:2]
        
        # --- B. Load GT Mask ---
        gt_mask = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
        if gt_mask is None:
            print(f"‚ö†Ô∏è  Cannot read GT mask: {gt_path}")
            return None
        
        # --- C. Convert Binary to Instance Mask ---
        instance_mask = binary_to_instance_mask(gt_mask)
        
        # --- D. Pad to 512 Multiples ---
        img_padded = pad_to_512_multiple(img)
        mask_padded = pad_to_512_multiple(instance_mask)
        h_pad, w_pad = img_padded.shape[:2]
        
        # --- E. Extract 512x512 Patches ---
        img_stack, mask_stack, metadata_list = [], [], []
        
        for y in range(0, h_pad, PATCH_SIZE):
            for x in range(0, w_pad, PATCH_SIZE):
                crop_img = img_padded[y:y+PATCH_SIZE, x:x+PATCH_SIZE]
                crop_mask = mask_padded[y:y+PATCH_SIZE, x:x+PATCH_SIZE]
                
                # Ensure full 512x512 patch
                if crop_img.shape[0] == PATCH_SIZE and crop_img.shape[1] == PATCH_SIZE:
                    img_stack.append(crop_img)
                    mask_stack.append(crop_mask)
                    metadata_list.append({
                        'original_file': patch_name,
                        'x': x,
                        'y': y,
                        'original_height': h_img,
                        'original_width': w_img
                    })
        
        if not img_stack:
            print(f"‚ö†Ô∏è  No patches extracted for {patch_name}")
            return None
        
        return {
            'patch_name': patch_name,
            'images': np.stack(img_stack, axis=0),
            'masks': np.stack(mask_stack, axis=0),
            'metadata': metadata_list
        }
    
    except Exception as e:
        print(f"‚ùå Error processing {patch_name}: {e}")
        import traceback
        traceback.print_exc()
        return None

# ---------------------------------------------------------------------
# 4. Process Single Slide and Determine Split
# ---------------------------------------------------------------------
def process_slide(slide_num):
    """
    Process all images for a single slide (e.g., Slide_01 with GT_01).
    Aggregates all patches from that slide into single zarr arrays.
    Determines which split (train/val/test) this slide belongs to.
    """
    slide_name = f"Slide_{slide_num:02d}"
    gt_name = f"GT_{slide_num:02d}"
    
    slide_dir = BASE_PATH / slide_name
    gt_dir = BASE_PATH / gt_name
    
    if not slide_dir.exists() or not gt_dir.exists():
        print(f"‚ö†Ô∏è  Skipping {slide_name}: directory not found")
        return 0, 0
    
    # Determine which split this slide belongs to
    if slide_num in TRAIN_SLIDES:
        split_name = "tnbc_train"
    elif slide_num in VAL_SLIDES:
        split_name = "tnbc_val"
    elif slide_num in TEST_SLIDES:
        split_name = "tnbc_test"
    else:
        print(f"‚ö†Ô∏è  Slide {slide_num} not assigned to any split!")
        return 0, 0
    
    # Find all image files in the slide
    img_files = sorted(slide_dir.glob("*.png"))
    
    if not img_files:
        print(f"‚ö†Ô∏è  No images found in {slide_name}")
        return 0, 0
    
    print(f"\nüìÇ Processing {slide_name} ({len(img_files)} images) -> {split_name}")
    
    # Build task list for this slide
    tasks = []
    for img_path in img_files:
        patch_name = img_path.name  # e.g., "01_1.png"
        gt_path = gt_dir / patch_name
        
        if gt_path.exists():
            tasks.append((img_path, gt_path, patch_name))
        else:
            print(f"   ‚ö†Ô∏è  No GT mask for {patch_name}")
    
    if not tasks:
        return 0, 0
    
    # Process all image pairs for this slide
    results = []
    for task in tqdm(tasks, desc=f"  {slide_name}", leave=False):
        result = process_image_pair(task)
        if result is not None:
            results.append(result)
    
    if not results:
        print(f"‚ö†Ô∏è  No valid results for {slide_name}")
        return 0, 0
    
    # --- Aggregate All Patches for This Slide ---
    all_images = []
    all_masks = []
    all_metadata = []
    
    for result in results:
        all_images.append(result['images'])
        all_masks.append(result['masks'])
        all_metadata.extend(result['metadata'])
    
    # Concatenate all patches
    final_images = np.concatenate(all_images, axis=0)
    final_masks = np.concatenate(all_masks, axis=0)
    
    # --- Save to Zarr in appropriate split directory ---
    split_dir = OUTPUT_BASE / split_name
    slide_out_path = split_dir / slide_name
    os.makedirs(slide_out_path, exist_ok=True)
    compressor = numcodecs.Blosc(cname='zstd', clevel=3)
    
    # Save images
    z_img = zarr.open_array(
        str(slide_out_path / 'images.zarr'),
        mode='w',
        shape=final_images.shape,
        chunks=(1, 512, 512, 3),
        dtype='uint8',
        compressor=compressor
    )
    z_img[:] = final_images
    
    # Save instance masks
    z_mask = zarr.open_array(
        str(slide_out_path / 'nuclei_masks.zarr'),
        mode='w',
        shape=final_masks.shape,
        chunks=(1, 512, 512),
        dtype='int32',
        compressor=compressor
    )
    z_mask[:] = final_masks
    
    # Save metadata
    pd.DataFrame(all_metadata).to_csv(
        slide_out_path / 'metadata.csv',
        index=False
    )
    
    # Get instance stats
    unique_instances = np.unique(final_masks)
    num_instances = len(unique_instances[unique_instances > 0])
    
    print(f"   ‚úÖ {slide_name}: {len(final_images)} patches, {num_instances} nuclei instances")
    
    return len(final_images), num_instances

# ---------------------------------------------------------------------
# 5. Main Processing Pipeline
# ---------------------------------------------------------------------
def main():
    print("=" * 70)
    print("üî¨ TNBC Dataset ‚Üí Zarr with Train/Val/Test Split")
    print("=" * 70)
    
    # Find all available slides
    slide_dirs = sorted(BASE_PATH.glob("Slide_*"))
    slide_numbers = [int(d.name.split('_')[1]) for d in slide_dirs]
    
    print(f"\nüìä Found {len(slide_numbers)} slides: {slide_numbers}")
    print(f"\nüìÇ Split Configuration:")
    print(f"   Training:   Slides {TRAIN_SLIDES} -> tnbc_train/")
    print(f"   Validation: Slides {VAL_SLIDES} -> tnbc_val/")
    print(f"   Testing:    Slides {TEST_SLIDES} -> tnbc_test/")
    print(f"\n   Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
    print(f"   Output: {OUTPUT_BASE}")
    
    # Create output directories
    OUTPUT_BASE.mkdir(parents=True, exist_ok=True)
    for split_name in ['tnbc_train', 'tnbc_val', 'tnbc_test']:
        (OUTPUT_BASE / split_name).mkdir(exist_ok=True)
    
    # Process each slide sequentially
    total_patches = 0
    total_instances = 0
    split_stats = {'tnbc_train': {'patches': 0, 'nuclei': 0, 'slides': []},
                   'tnbc_val': {'patches': 0, 'nuclei': 0, 'slides': []},
                   'tnbc_test': {'patches': 0, 'nuclei': 0, 'slides': []}}
    
    for slide_num in slide_numbers:
        num_patches, num_instances = process_slide(slide_num)
        total_patches += num_patches
        total_instances += num_instances
        
        # Track per-split stats
        if slide_num in TRAIN_SLIDES:
            split_stats['tnbc_train']['patches'] += num_patches
            split_stats['tnbc_train']['nuclei'] += num_instances
            split_stats['tnbc_train']['slides'].append(slide_num)
        elif slide_num in VAL_SLIDES:
            split_stats['tnbc_val']['patches'] += num_patches
            split_stats['tnbc_val']['nuclei'] += num_instances
            split_stats['tnbc_val']['slides'].append(slide_num)
        elif slide_num in TEST_SLIDES:
            split_stats['tnbc_test']['patches'] += num_patches
            split_stats['tnbc_test']['nuclei'] += num_instances
            split_stats['tnbc_test']['slides'].append(slide_num)
    
    # --- Summary Statistics ---
    print("\n" + "=" * 70)
    print("‚úÖ PROCESSING COMPLETE")
    print("=" * 70)
    print(f"Total slides processed: {len(slide_numbers)}")
    print(f"Total patches created: {total_patches}")
    print(f"Total nuclei instances: {total_instances}")
    
    print("\n" + "‚îÄ" * 70)
    print("üìä Per-Split Statistics:")
    print("‚îÄ" * 70)
    for split_name in ['tnbc_train', 'tnbc_val', 'tnbc_test']:
        stats = split_stats[split_name]
        print(f"\n{split_name.upper()}:")
        print(f"   Slides: {stats['slides']}")
        print(f"   Total patches: {stats['patches']}")
        print(f"   Total nuclei: {stats['nuclei']}")
        if stats['patches'] > 0:
            print(f"   Avg nuclei/patch: {stats['nuclei']/stats['patches']:.1f}")
    
    print("\n" + "=" * 70)
    print(f"Output directory: {OUTPUT_BASE}")
    print("=" * 70)
    
    # --- Print Directory Structure ---
    print("\nüìä Output Structure:")
    for split_name in ['tnbc_train', 'tnbc_val', 'tnbc_test']:
        split_dir = OUTPUT_BASE / split_name
        if split_dir.exists():
            slides = sorted([d.name for d in split_dir.iterdir() if d.is_dir()])
            print(f"   {split_name}/")
            for slide in slides:
                print(f"      ‚îî‚îÄ‚îÄ {slide}/")
    
    # Create split info file
    create_split_info_file(split_stats)

def create_split_info_file(split_stats):
    """Create a text file documenting the split."""
    info_file = OUTPUT_BASE / "split_info.txt"
    
    with open(info_file, 'w') as f:
        f.write("TNBC Dataset Split Information\n")
        f.write("=" * 70 + "\n\n")
        f.write("Directory Structure:\n")
        f.write(f"  tnbc_train/   - Slides {TRAIN_SLIDES}\n")
        f.write(f"  tnbc_val/     - Slides {VAL_SLIDES}\n")
        f.write(f"  tnbc_test/    - Slides {TEST_SLIDES}\n\n")
        f.write("Split Strategy: Patient-level split to prevent data leakage\n")
        f.write("Total: 11 patients -> 7 train / 2 val / 2 test\n\n")
        
        f.write("Statistics:\n")
        f.write("-" * 70 + "\n")
        for split_name in ['tnbc_train', 'tnbc_val', 'tnbc_test']:
            stats = split_stats[split_name]
            f.write(f"\n{split_name.upper()}:\n")
            f.write(f"  Slides: {stats['slides']}\n")
            f.write(f"  Total patches: {stats['patches']}\n")
            f.write(f"  Total nuclei: {stats['nuclei']}\n")
            if stats['patches'] > 0:
                f.write(f"  Avg nuclei/patch: {stats['nuclei']/stats['patches']:.1f}\n")
    
    print(f"\nüìÑ Split info saved to: {info_file}")

if __name__ == "__main__":
    main()

üî¨ TNBC Dataset ‚Üí Zarr with Train/Val/Test Split

üìä Found 11 slides: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

üìÇ Split Configuration:
   Training:   Slides [1, 2, 3, 4, 5, 6, 7] -> tnbc_train/
   Validation: Slides [8, 9] -> tnbc_val/
   Testing:    Slides [10, 11] -> tnbc_test/

   Patch size: 512x512
   Output: /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/tnbc/zarr_data

üìÇ Processing Slide_01 (7 images) -> tnbc_train


                                                         

   ‚úÖ Slide_01: 7 patches, 156 nuclei instances

üìÇ Processing Slide_02 (3 images) -> tnbc_train


                                                         

   ‚úÖ Slide_02: 3 patches, 97 nuclei instances

üìÇ Processing Slide_03 (5 images) -> tnbc_train


                                                         

   ‚úÖ Slide_03: 5 patches, 103 nuclei instances

üìÇ Processing Slide_04 (8 images) -> tnbc_train


                                                         

   ‚úÖ Slide_04: 8 patches, 187 nuclei instances

üìÇ Processing Slide_05 (4 images) -> tnbc_train


                                                         

   ‚úÖ Slide_05: 4 patches, 150 nuclei instances

üìÇ Processing Slide_06 (3 images) -> tnbc_train


                                                         

   ‚úÖ Slide_06: 3 patches, 97 nuclei instances

üìÇ Processing Slide_07 (3 images) -> tnbc_train


                                                         

   ‚úÖ Slide_07: 3 patches, 298 nuclei instances

üìÇ Processing Slide_08 (4 images) -> tnbc_val


                                                         

   ‚úÖ Slide_08: 4 patches, 133 nuclei instances

üìÇ Processing Slide_09 (6 images) -> tnbc_val


                                                         

   ‚úÖ Slide_09: 6 patches, 78 nuclei instances

üìÇ Processing Slide_10 (4 images) -> tnbc_test


                                                         

   ‚úÖ Slide_10: 4 patches, 191 nuclei instances

üìÇ Processing Slide_11 (3 images) -> tnbc_test


                                                         

   ‚úÖ Slide_11: 3 patches, 140 nuclei instances

‚úÖ PROCESSING COMPLETE
Total slides processed: 11
Total patches created: 50
Total nuclei instances: 1630

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
üìä Per-Split Statistics:
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ

TNBC_TRAIN:
   Slides: [1, 2, 3, 4, 5, 6, 7]
   Total patches: 33
   Total nuclei: 1088
   Avg nuclei/patch: 33.0

TNBC_VAL:
   Slides: [8, 9]
   Total patches: 10
   Total nuclei: 211
   Avg nuclei/patch: 21.1

TNBC_TEST:
   Slides: [10, 11]
   Total patches: 7
   Total nuclei: 331
   Avg nuclei/patch: 47.3

Output directory: /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/tnbc/za

In [10]:
import zarr
import numpy as np
import matplotlib.pyplot as plt
import os
import random
from pathlib import Path

# --- CONFIGURATION ---
ZARR_DATA_ROOT = Path("/rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/tnbc/zarr_data")

def colorize_instances(mask):
    """
    Creates a random RGB color map for an instance mask.
    Background (0) is always Black.
    Each unique instance ID gets a unique color.
    """
    max_id = int(mask.max())
    if max_id == 0:
        return np.zeros((*mask.shape, 3), dtype=np.uint8)
    
    # Create a random color for every possible ID
    # Using a fixed seed so colors are consistent for the same mask, 
    # but random across IDs.
    np.random.seed(42) 
    colors = np.random.randint(50, 255, size=(max_id + 1, 3), dtype=np.uint8)
    
    # Force background to black
    colors[0] = [0, 0, 0]
    
    # Map the IDs to colors
    colored_mask = colors[mask]
    return colored_mask

def get_instance_stats(mask):
    """Get statistics about instances in the mask"""
    unique_ids = np.unique(mask)
    num_instances = len(unique_ids[unique_ids > 0])  # Exclude background
    
    instance_sizes = []
    for inst_id in unique_ids:
        if inst_id > 0:
            size = np.sum(mask == inst_id)
            instance_sizes.append(size)
    
    return {
        'num_instances': num_instances,
        'min_size': min(instance_sizes) if instance_sizes else 0,
        'max_size': max(instance_sizes) if instance_sizes else 0,
        'avg_size': np.mean(instance_sizes) if instance_sizes else 0
    }

def visualize_dataset(root_path, samples_per_category=5):
    """
    Visualize random samples from the TNBC zarr dataset.
    Works with Slide_XX directories.
    """
    
    print("=" * 70)
    print(f"üëÄ INSPECTING TNBC ZARR DATA AT: {root_path}")
    print("=" * 70)
    
    if not root_path.exists():
        print(f"\n‚ùå Path not found: {root_path}")
        return
    
    # Get all slide directories
    slides = sorted([d for d in root_path.iterdir() if d.is_dir() and d.name.startswith('Slide_')])
    
    if not slides:
        print(f"\n‚ö†Ô∏è  No slides found in {root_path}")
        return
    
    # Randomly select slides
    chosen = random.sample(slides, min(len(slides), samples_per_category))
    
    print(f"\n{'‚îÄ' * 70}")
    print(f"üìÇ TNBC DATASET")
    print(f"   Total slides: {len(slides)}")
    print(f"   Showing: {len(chosen)} random slides")
    print(f"{'‚îÄ' * 70}")
    
    for slide_idx, slide_path in enumerate(chosen, 1):
        slide_name = slide_path.name
        
        try:
            # Open zarr arrays
            z_img = zarr.open(str(slide_path / "images.zarr"), mode='r')
            z_msk = zarr.open(str(slide_path / "nuclei_masks.zarr"), mode='r')
            
            # Get metadata if available
            metadata_path = slide_path / "metadata.csv"
            if metadata_path.exists():
                import pandas as pd
                metadata = pd.read_csv(metadata_path)
            else:
                metadata = None
            
            # Pick a random patch from this slide
            num_patches = z_img.shape[0]
            idx = random.randint(0, num_patches - 1)
            
            img_patch = z_img[idx]
            mask_patch = z_msk[idx]
            
            # Get instance statistics
            stats = get_instance_stats(mask_patch)
            
            # Print info
            print(f"\n   [{slide_idx}/{len(chosen)}] {slide_name}")
            print(f"        Total patches: {num_patches}")
            print(f"        Selected patch: #{idx}")
            print(f"        Nuclei in patch: {stats['num_instances']}")
            if stats['num_instances'] > 0:
                print(f"        Size range: {stats['min_size']}-{stats['max_size']} pixels (avg: {stats['avg_size']:.0f})")
            
            if metadata is not None:
                patch_meta = metadata.iloc[idx]
                print(f"        Original file: {patch_meta['original_file']}")
                print(f"        Position: x={patch_meta['x']}, y={patch_meta['y']}")
                print(f"        Original size: {patch_meta['original_width']}x{patch_meta['original_height']}")
            
            # --- PLOTTING ---
            fig, ax = plt.subplots(1, 3, figsize=(18, 6))
            fig.suptitle(
                f"{slide_name} | Patch #{idx} | {stats['num_instances']} nuclei",
                fontsize=14,
                fontweight='bold'
            )
            
            # 1. Original Image
            ax[0].imshow(img_patch)
            ax[0].set_title("Original Image (H&E)", fontsize=12)
            ax[0].axis('off')
            
            # 2. Instance Mask (Random Colors)
            colored_mask = colorize_instances(mask_patch)
            ax[1].imshow(colored_mask)
            ax[1].set_title(
                f"Instance Mask\n({stats['num_instances']} nuclei, IDs: 1-{mask_patch.max()})",
                fontsize=12
            )
            ax[1].axis('off')
            
            # 3. Overlay
            # Create overlay with transparency
            overlay_img = img_patch.copy().astype(float)
            
            # Create alpha channel based on mask
            alpha = np.where(mask_patch > 0, 0.5, 0.0)  # 50% transparency for nuclei
            
            # Blend the colored mask with the original image
            for c in range(3):
                overlay_img[:, :, c] = (
                    img_patch[:, :, c] * (1 - alpha) + 
                    colored_mask[:, :, c] * alpha
                )
            
            overlay_img = overlay_img.astype(np.uint8)
            
            ax[2].imshow(overlay_img)
            ax[2].set_title("Overlay (Image + Instances)", fontsize=12)
            ax[2].axis('off')
            
            plt.tight_layout()
            plt.show()
            print(f"        ‚úÖ Visualization complete")
            
        except Exception as e:
            print(f"\n        ‚ùå Error reading {slide_name}: {e}")
            import traceback
            traceback.print_exc()
    
    print("\n" + "=" * 70)
    print("‚úÖ Dataset inspection complete!")
    print("=" * 70)

def print_dataset_summary(root_path):
    """Print summary statistics for the entire TNBC dataset"""
    
    print("\n" + "=" * 70)
    print("üìä TNBC DATASET SUMMARY")
    print("=" * 70)
    
    if not root_path.exists():
        print(f"‚ùå Path not found: {root_path}")
        return
    
    slides = sorted([d for d in root_path.iterdir() if d.is_dir() and d.name.startswith('Slide_')])
    
    if not slides:
        print("‚ö†Ô∏è  No slides found")
        return
    
    total_patches = 0
    total_nuclei = 0
    slide_info = []
    
    for slide_path in slides:
        try:
            z_msk = zarr.open(str(slide_path / "nuclei_masks.zarr"), mode='r')
            num_patches = z_msk.shape[0]
            total_patches += num_patches
            
            # Count nuclei in all patches
            slide_nuclei = 0
            for patch_idx in range(z_msk.shape[0]):
                mask = z_msk[patch_idx]
                unique_ids = np.unique(mask)
                slide_nuclei += len(unique_ids[unique_ids > 0])
            
            total_nuclei += slide_nuclei
            slide_info.append({
                'name': slide_path.name,
                'patches': num_patches,
                'nuclei': slide_nuclei
            })
        except Exception as e:
            print(f"‚ö†Ô∏è  Error processing {slide_path.name}: {e}")
    
    print(f"\nTotal slides: {len(slides)}")
    print(f"Total patches: {total_patches}")
    print(f"Total nuclei: {total_nuclei}")
    if total_patches > 0:
        print(f"Avg nuclei/patch: {total_nuclei/total_patches:.1f}")
    
    print(f"\n{'‚îÄ' * 70}")
    print("Per-slide breakdown:")
    print(f"{'‚îÄ' * 70}")
    for info in slide_info:
        avg_per_patch = info['nuclei'] / info['patches'] if info['patches'] > 0 else 0
        print(f"   {info['name']}: {info['patches']} patches, {info['nuclei']} nuclei (avg: {avg_per_patch:.1f}/patch)")
    
    print("\n" + "=" * 70)

def verify_data_integrity(root_path):
    """Check for common issues in the dataset"""
    print("\n" + "=" * 70)
    print("üîç DATA INTEGRITY CHECK")
    print("=" * 70)
    
    if not root_path.exists():
        print(f"‚ùå Path not found: {root_path}")
        return
    
    slides = sorted([d for d in root_path.iterdir() if d.is_dir() and d.name.startswith('Slide_')])
    
    issues = []
    
    for slide_path in slides:
        slide_name = slide_path.name
        
        # Check for required files
        if not (slide_path / "images.zarr").exists():
            issues.append(f"{slide_name}: Missing images.zarr")
        if not (slide_path / "nuclei_masks.zarr").exists():
            issues.append(f"{slide_name}: Missing nuclei_masks.zarr")
        if not (slide_path / "metadata.csv").exists():
            issues.append(f"{slide_name}: Missing metadata.csv")
        
        try:
            z_img = zarr.open(str(slide_path / "images.zarr"), mode='r')
            z_msk = zarr.open(str(slide_path / "nuclei_masks.zarr"), mode='r')
            
            # Check shape consistency
            if z_img.shape[0] != z_msk.shape[0]:
                issues.append(f"{slide_name}: Image/mask count mismatch ({z_img.shape[0]} vs {z_msk.shape[0]})")
            
            # Check patch dimensions
            if z_img.shape[1:3] != (512, 512):
                issues.append(f"{slide_name}: Incorrect image patch size {z_img.shape[1:3]}")
            if z_msk.shape[1:3] != (512, 512):
                issues.append(f"{slide_name}: Incorrect mask patch size {z_msk.shape[1:3]}")
            
            # Sample check: verify at least one patch has nuclei
            has_nuclei = False
            for idx in range(min(10, z_msk.shape[0])):  # Check first 10 patches
                if z_msk[idx].max() > 0:
                    has_nuclei = True
                    break
            
            if not has_nuclei:
                issues.append(f"{slide_name}: Warning - No nuclei found in first 10 patches")
                
        except Exception as e:
            issues.append(f"{slide_name}: Error reading zarr - {e}")
    
    if issues:
        print("\n‚ö†Ô∏è  Issues found:")
        for issue in issues:
            print(f"   ‚Ä¢ {issue}")
    else:
        print("\n‚úÖ All checks passed! Dataset looks good.")
    
    print("\n" + "=" * 70)

if __name__ == "__main__":
    if ZARR_DATA_ROOT.exists():
        # First verify data integrity
        verify_data_integrity(ZARR_DATA_ROOT)
        
        # Print overall summary
        print_dataset_summary(ZARR_DATA_ROOT)
        
        # Then visualize random samples
        print("\n")
        visualize_dataset(ZARR_DATA_ROOT, samples_per_category=10)
    else:
        print(f"‚ùå Path not found: {ZARR_DATA_ROOT}")
        print("   Please check your ZARR_DATA_ROOT configuration.")


üîç DATA INTEGRITY CHECK

‚úÖ All checks passed! Dataset looks good.


üìä TNBC DATASET SUMMARY
‚ö†Ô∏è  No slides found


üëÄ INSPECTING TNBC ZARR DATA AT: /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/tnbc/zarr_data

‚ö†Ô∏è  No slides found in /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/tnbc/zarr_data
