{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cellpose Interactive Demo\n",
    "\n",
    "**Author:** Janan Arslan  \n",
    "**Date:** 2025-06-30  \n",
    "**License:** CC BY-NC-ND 4.0\n",
    "\n",
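    "This notebook provides a hands-on, interactive demonstration of cell segmentation with Cellpose. It uses the Cellpose 3.x API (`models.Cellpose` with the built-in `cyto`, `cyto2`, and `nuclei` models).\n",
    "\n",
    "## Contents\n",
    "1. Installation\n",
    "2. Importing libraries and checking the version\n",
    "3. Loading sample images\n",
    "4. Visualizing sample images\n",
    "5. Basic segmentation\n",
    "6. Visualizing segmentation results\n",
    "7. Comparing models\n",
    "8. Interactive parameter exploration\n",
    "9. Batch processing\n",
    "10. Exporting results"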
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Installation\n",
    "\n",
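    "First, install Cellpose and its plotting and widget dependencies; this can take a few minutes. Cellpose is pinned below version 4.0 because Cellpose 4 (Cellpose-SAM) changed the model API: the `models.Cellpose` class and the `cyto`/`cyto2`/`nuclei` model types used in this notebook belong to the 3.x series."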
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Cellpose 3.x without the GUI extras (for Binder compatibility).\n",
    "# %pip installs into the environment of the running kernel.\n",
    "%pip install \"cellpose<4\" --quiet\n",
    "%pip install matplotlib ipywidgets --quiet\n",
    "\n",
    "print(\"Installation complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Import Libraries and Check Version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from urllib.request import urlretrieve\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output\n",
    "from cellpose import models, io, plot\n",
    "\n",
    "# Check Cellpose version\n",
    "import cellpose\n",
    "print(f\"Cellpose version: {cellpose.__version__}\")\n",
    "print(\"✓ All libraries imported successfully!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load Sample Images\n",
    "\n",
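    "The Cellpose GitHub repository includes small test images. Let's download a few of them; if every download fails (for example, without internet access), the cell falls back to a synthetic image."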
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download sample images from Cellpose GitHub\n",
    "sample_urls = {\n",
    "    'cells': 'https://raw.githubusercontent.com/MouseLand/cellpose/main/cellpose/test/assets/3channel_RGB.tif',\n",
    "    'nuclei': 'https://raw.githubusercontent.com/MouseLand/cellpose/main/cellpose/test/assets/rgb_2D.png',\n",
    "    'cyto': 'https://raw.githubusercontent.com/MouseLand/cellpose/main/cellpose/test/assets/2D_RGB.tif'\n",
    "}\n",
    "\n",
    "# Create directory for images\n",
    "os.makedirs('sample_images', exist_ok=True)\n",
    "\n",
    "# Download images\n",
    "images = {}\n",
    "for name, url in sample_urls.items():\n",
    "    # Keep the file extension from the URL (.tif or .png)\n",
    "    filename = os.path.join('sample_images', name + os.path.splitext(url)[1])\n",
    "    try:\n",
    "        urlretrieve(url, filename)\n",
    "        images[name] = io.imread(filename)\n",
    "        print(f\"✓ Downloaded {name} image: shape {images[name].shape}\")\n",
    "    except Exception as err:\n",
    "        print(f\"⚠ Could not download {name} image: {err}\")\n",
    "\n",
    "# If every download failed, create a synthetic image instead\n",
    "if len(images) == 0:\n",
    "    print(\"\\nCreating synthetic test image...\")\n",
    "    # Create a synthetic RGB image with filled circles\n",
    "    img_size = 512\n",
    "    img = np.zeros((img_size, img_size, 3), dtype=np.uint8)\n",
    "    np.random.seed(0)  # reproducible layout\n",
    "    \n",
    "    # Add some circular \"cells\" (OpenCV is installed as a Cellpose dependency)\n",
    "    import cv2\n",
    "    for _ in range(20):\n",
    "        center = (np.random.randint(50, img_size-50), np.random.randint(50, img_size-50))\n",
    "        radius = np.random.randint(20, 40)\n",
    "        color = (np.random.randint(100, 255), np.random.randint(100, 255), np.random.randint(100, 255))\n",
    "        cv2.circle(img, center, radius, color, -1)\n",
    "        cv2.circle(img, center, radius, (0, 0, 0), 2)\n",
    "    \n",
    "    images['synthetic'] = img\n",
    "    print(\"✓ Created synthetic test image\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Visualize Sample Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display all loaded images\n",
    "n_images = len(images)\n",
    "fig, axes = plt.subplots(1, n_images, figsize=(5*n_images, 5))\n",
    "\n",
    "if n_images == 1:\n",
    "    axes = [axes]\n",
    "\n",
    "for idx, (name, img) in enumerate(images.items()):\n",
    "    axes[idx].imshow(img)\n",
    "    axes[idx].set_title(f'{name.capitalize()} Image\\nShape: {img.shape}')\n",
    "    axes[idx].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
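  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Microscopy images are often 16-bit or use only a narrow part of their intensity range, so `imshow` can render them dark or clipped. The next cell is a minimal sketch of percentile-based contrast stretching for display only; `normalize_for_display` is a hypothetical helper written for this notebook, not a Cellpose function. Cellpose normalizes images internally during `model.eval`, so the arrays in `images` are left unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def normalize_for_display(image, low=1, high=99):\n",
    "    # Map the low/high intensity percentiles to 0 and 1, clipping outliers (display only)\n",
    "    image = image.astype(np.float32)\n",
    "    lo, hi = np.percentile(image, [low, high])\n",
    "    if hi <= lo:  # flat image: avoid dividing by zero\n",
    "        return np.zeros_like(image)\n",
    "    return np.clip((image - lo) / (hi - lo), 0, 1)\n",
    "\n",
    "# Quick check on a hypothetical 16-bit intensity ramp: the output spans [0, 1]\n",
    "ramp = np.linspace(0, 4000, 100).astype(np.uint16).reshape(10, 10)\n",
    "scaled = normalize_for_display(ramp)\n",
    "print(scaled.min(), scaled.max())\n",
    "\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(5*len(images), 5), squeeze=False)\n",
    "for idx, (name, image) in enumerate(images.items()):\n",
    "    axes[0, idx].imshow(normalize_for_display(image), cmap='gray')  # cmap is ignored for RGB images\n",
    "    axes[0, idx].set_title(f'{name.capitalize()} (contrast-stretched)')\n",
    "    axes[0, idx].axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },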
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Basic Cellpose Segmentation\n",
    "\n",
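    "Let's run Cellpose on the first loaded image using the `cyto2` model. With `diameter=None`, Cellpose estimates the cell diameter itself, and `channels=[0,0]` runs the model in grayscale mode (no separate nuclear channel)."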
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize Cellpose model\n",
    "model = models.Cellpose(model_type='cyto2')\n",
    "print(\"Model loaded successfully!\")\n",
    "\n",
    "# Get the first image\n",
    "img_name = list(images.keys())[0]\n",
    "img = images[img_name]\n",
    "\n",
    "# Run segmentation\n",
    "print(f\"\\nRunning segmentation on {img_name} image...\")\n",
    "masks, flows, styles, diams = model.eval(img, diameter=None, channels=[0,0])\n",
    "\n",
    "print(\"✓ Segmentation complete!\")\n",
    "print(f\"  - Number of cells detected: {len(np.unique(masks)) - 1}\")\n",
    "print(f\"  - Estimated cell diameter: {diams:.1f} pixels\")"
   ]
  },
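  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell count above is `len(np.unique(masks)) - 1` because Cellpose labels background pixels 0 and gives each detected cell its own integer label (1, 2, ...). The same label mask also yields per-cell areas. The sketch below counts pixels per label with `np.bincount` and converts each area A into an equivalent-circle diameter d = 2 * sqrt(A / pi). This is only a rough cross-check: Cellpose's own diameter estimate comes from its size model, so the two numbers need not agree. `cell_areas` is a hypothetical helper, not part of Cellpose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cell_areas(label_mask):\n",
    "    # Count pixels per label value; label 0 (background) is dropped\n",
    "    counts = np.bincount(label_mask.ravel().astype(np.int64))\n",
    "    labels = np.nonzero(counts)[0]\n",
    "    labels = labels[labels != 0]\n",
    "    return labels, counts[labels]\n",
    "\n",
    "# Toy label mask (hypothetical): background 0, cell 1 has 3 pixels, cell 2 has 2\n",
    "toy = np.array([[0, 1, 1],\n",
    "                [0, 1, 2],\n",
    "                [0, 0, 2]])\n",
    "toy_labels, toy_areas = cell_areas(toy)\n",
    "print(dict(zip(toy_labels.tolist(), toy_areas.tolist())))  # {1: 3, 2: 2}\n",
    "\n",
    "# Apply the same helper to the Cellpose masks from the previous cell\n",
    "labels, areas = cell_areas(masks)\n",
    "if len(areas) > 0:\n",
    "    eq_diam = 2 * np.sqrt(areas / np.pi)\n",
    "    print(f'{len(labels)} cells, median area {np.median(areas):.0f} px')\n",
    "    print(f'Median equivalent diameter: {np.median(eq_diam):.1f} px (Cellpose estimate: {diams:.1f} px)')\n",
    "else:\n",
    "    print('No cells were detected in this image.')"
   ]
  },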
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Visualize Segmentation Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create visualization\n",
    "fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
    "\n",
    "# Original image\n",
    "axes[0,0].imshow(img)\n",
    "axes[0,0].set_title('Original Image')\n",
    "axes[0,0].axis('off')\n",
    "\n",
    "# Segmentation masks\n",
    "axes[0,1].imshow(masks, cmap='tab20')\n",
    "axes[0,1].set_title(f'Segmentation Masks\\n({len(np.unique(masks))-1} cells)')\n",
    "axes[0,1].axis('off')\n",
    "\n",
    "# Overlay\n",
    "overlay = plot.mask_overlay(img, masks)\n",
    "axes[0,2].imshow(overlay)\n",
    "axes[0,2].set_title('Overlay')\n",
    "axes[0,2].axis('off')\n",
    "\n",
    "# Flow fields: flows[0] is an RGB rendering of the flows, flows[1] is the\n",
    "# (dY, dX) flow array with shape (2, Ly, Lx), flows[2] is the cell probability\n",
    "axes[1,0].imshow(flows[1][1], cmap='RdBu_r')\n",
    "axes[1,0].set_title('Horizontal Flow (X)')\n",
    "axes[1,0].axis('off')\n",
    "\n",
    "axes[1,1].imshow(flows[1][0], cmap='RdBu_r')\n",
    "axes[1,1].set_title('Vertical Flow (Y)')\n",
    "axes[1,1].axis('off')\n",
    "\n",
    "# Cell probability (raw network output, before the sigmoid)\n",
    "axes[1,2].imshow(flows[2], cmap='plasma')\n",
    "axes[1,2].set_title('Cell Probability')\n",
    "axes[1,2].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
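  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cellpose predicts, for every pixel, a flow vector pointing toward the center of its cell plus a cell-probability value; pixels whose probability exceeds `cellprob_threshold` are moved along the flows and grouped into masks. The sketch below assumes the Cellpose 3.x output layout for a single 2D image, where `flows[1]` is the (dY, dX) flow array of shape (2, Ly, Lx) and `flows[2]` is the raw cell-probability map. It plots the flow magnitude and compares the fraction of pixels above the default threshold of 0.0 with the fraction that ended up inside a mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Assumed Cellpose 3.x layout for one 2D image: flows[1] = (dY, dX), flows[2] = cell probability\n",
    "dP, cellprob = flows[1], flows[2]\n",
    "magnitude = np.sqrt(dP[0]**2 + dP[1]**2)\n",
    "\n",
    "# Pixels above the default cellprob_threshold (0.0) are the candidates Cellpose moves along the flows\n",
    "candidate = cellprob > 0.0\n",
    "print(f'Pixels above cellprob threshold: {candidate.mean():.1%}')\n",
    "print(f'Pixels assigned to a mask:       {(masks > 0).mean():.1%}')\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
    "axes[0].imshow(magnitude, cmap='magma')\n",
    "axes[0].set_title('Flow magnitude')\n",
    "axes[1].imshow(candidate, cmap='gray')\n",
    "axes[1].set_title('cellprob > 0.0')\n",
    "for ax in axes:\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },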
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Compare Different Models\n",
    "\n",
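    "Let's compare three built-in Cellpose models on the same image: `cyto`, `cyto2`, and `nuclei`. The `nuclei` model is trained on nuclei, so its results on whole-cell images are expected to differ from those of the two cytoplasm models."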
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define models to compare\n",
    "model_types = ['cyto', 'cyto2', 'nuclei']\n",
    "results = {}\n",
    "\n",
    "# Run each model\n",
    "for model_type in model_types:\n",
    "    print(f\"Running {model_type} model...\")\n",
    "    cmp_model = models.Cellpose(model_type=model_type)\n",
    "    # Separate names keep the section 5 results (masks, flows, diams) intact\n",
    "    cmp_masks, _, _, cmp_diams = cmp_model.eval(img, diameter=None, channels=[0,0])\n",
    "    results[model_type] = {\n",
    "        'masks': cmp_masks,\n",
    "        'n_cells': len(np.unique(cmp_masks)) - 1,\n",
    "        'diameter': cmp_diams\n",
    "    }\n",
    "    print(f\"  ✓ Detected {results[model_type]['n_cells']} cells\")\n",
    "\n",
    "# Visualize comparisons\n",
    "fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
    "\n",
    "# Original\n",
    "axes[0].imshow(img)\n",
    "axes[0].set_title('Original Image')\n",
    "axes[0].axis('off')\n",
    "\n",
    "# Model results\n",
    "for idx, model_type in enumerate(model_types):\n",
    "    overlay = plot.mask_overlay(img, results[model_type]['masks'])\n",
    "    axes[idx+1].imshow(overlay)\n",
    "    axes[idx+1].set_title(f'{model_type.capitalize()} Model\\n{results[model_type][\"n_cells\"]} cells')\n",
    "    axes[idx+1].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
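  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cell counts alone can hide how differently the models draw boundaries. One simple extra check is the intersection-over-union (IoU) of the foreground, i.e. all pixels with a label greater than 0, for each pair of models. This sketch ignores which cell each pixel belongs to, so two models that split the same region into different numbers of cells can still score 1.0; `foreground_iou` is a hypothetical helper, and a per-object matching metric would be needed to compare individual cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "import numpy as np\n",
    "\n",
    "def foreground_iou(mask_a, mask_b):\n",
    "    # IoU of the 'any cell' pixels, ignoring which label each pixel carries\n",
    "    a, b = mask_a > 0, mask_b > 0\n",
    "    union = np.logical_or(a, b).sum()\n",
    "    return 1.0 if union == 0 else np.logical_and(a, b).sum() / union\n",
    "\n",
    "# Toy check (hypothetical masks): 2 shared foreground pixels out of 4 in the union\n",
    "m1 = np.array([[1, 1, 0], [0, 0, 0]])\n",
    "m2 = np.array([[2, 2, 3], [3, 0, 0]])\n",
    "print(foreground_iou(m1, m2))  # 0.5\n",
    "\n",
    "# Pairwise comparison of the models run above\n",
    "for a_name, b_name in combinations(results, 2):\n",
    "    iou = foreground_iou(results[a_name]['masks'], results[b_name]['masks'])\n",
    "    print(f'{a_name} vs {b_name}: foreground IoU = {iou:.2f}')"
   ]
  },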
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Interactive Parameter Exploration\n",
    "\n",
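    "Use the widgets below to see how the model type, diameter, flow threshold, and cell-probability threshold change the segmentation. Set the diameter to 0 to let Cellpose estimate it. Increasing `flow_threshold` keeps more masks (including some ill-shaped ones), while decreasing `cellprob_threshold` lets more and larger masks through."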
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create interactive widgets\n",
    "model_widget = widgets.Dropdown(\n",
    "    options=['cyto', 'cyto2', 'nuclei'],\n",
    "    value='cyto2',\n",
    "    description='Model:'\n",
    ")\n",
    "\n",
    "diameter_widget = widgets.IntSlider(\n",
    "    value=30,\n",
    "    min=0,\n",
    "    max=100,\n",
    "    step=5,\n",
    "    description='Diameter:',\n",
    "    tooltip='Set to 0 for automatic estimation',\n",
    "    continuous_update=False  # segment only when the slider is released\n",
    ")\n",
    "\n",
    "flow_threshold_widget = widgets.FloatSlider(\n",
    "    value=0.4,\n",
    "    min=0.0,\n",
    "    max=3.0,\n",
    "    step=0.1,\n",
    "    description='Flow Threshold:',\n",
    "    continuous_update=False\n",
    ")\n",
    "\n",
    "cellprob_threshold_widget = widgets.FloatSlider(\n",
    "    value=0.0,\n",
    "    min=-6.0,\n",
    "    max=6.0,\n",
    "    step=0.5,\n",
    "    description='Cell Prob Threshold:',\n",
    "    continuous_update=False\n",
    ")\n",
    "\n",
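    "output = widgets.Output()\n",
    "\n",
    "# Cache loaded models so each widget change does not reload the network weights\n",
    "model_cache = {}\n",
    "\n",
    "def update_segmentation(model_type, diameter, flow_threshold, cellprob_threshold):\n",
    "    with output:\n",
    "        clear_output(wait=True)\n",
    "        \n",
    "        # Run segmentation (each model type is loaded only once)\n",
    "        if model_type not in model_cache:\n",
    "            model_cache[model_type] = models.Cellpose(model_type=model_type)\n",
    "        model = model_cache[model_type]\n",
    "        diameter_use = None if diameter == 0 else diameter\n",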
    "        \n",
    "        masks, flows, styles, diams = model.eval(\n",
    "            img, \n",
    "            diameter=diameter_use, \n",
    "            channels=[0,0],\n",
    "            flow_threshold=flow_threshold,\n",
    "            cellprob_threshold=cellprob_threshold\n",
    "        )\n",
    "        \n",
    "        # Display results\n",
    "        fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "        \n",
    "        axes[0].imshow(img)\n",
    "        axes[0].set_title('Original')\n",
    "        axes[0].axis('off')\n",
    "        \n",
    "        axes[1].imshow(masks, cmap='tab20')\n",
    "        axes[1].set_title(f'Masks ({len(np.unique(masks))-1} cells)')\n",
    "        axes[1].axis('off')\n",
    "        \n",
    "        overlay = plot.mask_overlay(img, masks)\n",
    "        axes[2].imshow(overlay)\n",
    "        axes[2].set_title('Overlay')\n",
    "        axes[2].axis('off')\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        \n",
    "        print(f\"Estimated diameter: {diams:.1f} pixels\")\n",
    "\n",
    "# Create interactive interface\n",
    "interact = widgets.interactive(\n",
    "    update_segmentation,\n",
    "    model_type=model_widget,\n",
    "    diameter=diameter_widget,\n",
    "    flow_threshold=flow_threshold_widget,\n",
    "    cellprob_threshold=cellprob_threshold_widget\n",
    ")\n",
    "\n",
    "display(interact, output)"
   ]
  },
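  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike a probability, `cellprob_threshold` can be negative: it is compared against the network's raw cell-probability output, which this sketch assumes is a logit. The logistic sigmoid p = 1 / (1 + exp(-t)) converts a threshold t into a probability. The cell below (with the hypothetical helper `logit_to_probability`) prints this conversion for a few slider values; the default of 0.0 corresponds to p = 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def logit_to_probability(threshold):\n",
    "    # Logistic sigmoid: maps a logit threshold t to a probability in (0, 1)\n",
    "    return 1.0 / (1.0 + np.exp(-threshold))\n",
    "\n",
    "# A few values from the slider range used above (-6 to 6)\n",
    "for t in [-6.0, -3.0, 0.0, 3.0, 6.0]:\n",
    "    print(f'cellprob_threshold = {t:+.1f}  ->  p = {logit_to_probability(t):.4f}')"
   ]
  },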
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Batch Processing Example\n",
    "\n",
    "Segment every loaded image with the same `cyto2` model and plot a summary of the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process all images\n",
    "model = models.Cellpose(model_type='cyto2')\n",
    "batch_results = {}\n",
    "\n",
    "print(\"Processing all images...\\n\")\n",
    "\n",
    "# Separate loop names keep the globals img and img_name (used by the widgets above)\n",
    "# and the section 5 results unchanged\n",
    "for name, image in images.items():\n",
    "    print(f\"Processing {name}...\")\n",
    "    batch_masks, _, _, batch_diams = model.eval(image, diameter=None, channels=[0,0])\n",
    "    \n",
    "    batch_results[name] = {\n",
    "        'masks': batch_masks,\n",
    "        'n_cells': len(np.unique(batch_masks)) - 1,\n",
    "        'diameter': batch_diams,\n",
    "        'image': image\n",
    "    }\n",
    "    print(f\"  ✓ Found {batch_results[name]['n_cells']} cells\\n\")\n",
    "\n",
    "# Create summary visualization\n",
    "n_imgs = len(batch_results)\n",
    "fig, axes = plt.subplots(2, n_imgs, figsize=(5*n_imgs, 10))\n",
    "\n",
    "if n_imgs == 1:\n",
    "    axes = axes.reshape(-1, 1)\n",
    "\n",
    "for idx, (name, result) in enumerate(batch_results.items()):\n",
    "    # Original\n",
    "    axes[0, idx].imshow(result['image'])\n",
    "    axes[0, idx].set_title(f'{name.capitalize()}\\nOriginal')\n",
    "    axes[0, idx].axis('off')\n",
    "    \n",
    "    # Segmentation\n",
    "    overlay = plot.mask_overlay(result['image'], result['masks'])\n",
    "    axes[1, idx].imshow(overlay)\n",
    "    axes[1, idx].set_title(f\"{result['n_cells']} cells\\nØ {result['diameter']:.1f}px\")\n",
    "    axes[1, idx].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Export Results\n",
    "\n",
    "Save each label mask as a NumPy `.npy` file (0 = background, 1, 2, ... = individual cells) together with a PNG overlay for quick inspection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create output directory\n",
    "os.makedirs('output', exist_ok=True)\n",
    "\n",
    "# Save masks for each image\n",
    "for name, result in batch_results.items():\n",
    "    # Save mask\n",
    "    mask_filename = f'output/{name}_masks.npy'\n",
    "    np.save(mask_filename, result['masks'])\n",
    "    \n",
    "    # Save overlay visualization\n",
    "    fig, ax = plt.subplots(figsize=(8, 8))\n",
    "    overlay = plot.mask_overlay(result['image'], result['masks'])\n",
    "    ax.imshow(overlay)\n",
    "    ax.set_title(f\"{name.capitalize()} - {result['n_cells']} cells detected\")\n",
    "    ax.axis('off')\n",
    "    plt.savefig(f'output/{name}_overlay.png', dpi=150, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    print(f\"✓ Saved results for {name}\")\n",
    "\n",
    "print(\"\\nAll results saved to 'output' directory!\")"
   ]
  },
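  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reuse the exported masks later, load them back with `np.load`. This sketch reloads each `.npy` file written above, checks that it is identical to the in-memory mask, and recounts the cells from the nonzero label values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "for name, result in batch_results.items():\n",
    "    path = os.path.join('output', f'{name}_masks.npy')\n",
    "    loaded = np.load(path)\n",
    "    # The round trip through .npy should preserve shape, dtype and every label exactly\n",
    "    assert np.array_equal(loaded, result['masks']) and loaded.dtype == result['masks'].dtype\n",
    "    n_cells = len(np.unique(loaded[loaded > 0]))\n",
    "    print(f'{name}: shape {loaded.shape}, dtype {loaded.dtype}, {n_cells} cells')"
   ]
  },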
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this notebook, we've covered:\n",
    "\n",
    "1. ✅ Installing Cellpose in a Jupyter environment\n",
    "2. ✅ Loading and visualizing sample images\n",
    "3. ✅ Running basic segmentation with Cellpose\n",
    "4. ✅ Understanding the flow fields and probability maps\n",
    "5. ✅ Comparing different pre-trained models\n",
    "6. ✅ Interactive parameter exploration\n",
    "7. ✅ Batch processing multiple images\n",
    "8. ✅ Exporting results for further analysis\n",
    "\n",
    "### Next Steps\n",
    "\n",
    "- Try Cellpose on your own images\n",
    "- Explore Cellpose 2.0's custom training features\n",
    "- Test Cellpose 3.0's image restoration capabilities\n",
    "- Integrate with QuPath for whole-slide analysis\n",
    "\n",
    "### Resources\n",
    "\n",
    "- [Cellpose Documentation](https://cellpose.readthedocs.io/)\n",
    "- [Cellpose GitHub](https://github.com/MouseLand/cellpose)\n",
    "- [Cellpose Paper](https://www.nature.com/articles/s41592-020-01018-x)\n",
    "- [Interactive Demo](https://www.cellpose.org)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}