# Configure SBS Parameters

This notebook should be used as a test for ensuring correct SBS image loading and processing before running the SBS module.
Cells marked with <font color='red'>SET PARAMETERS</font> contain crucial variables that need to be set according to your specific experimental setup and data organization.
Please review and modify these variables as needed before proceeding with the analysis.

## <font color='red'>SET PARAMETERS</font>

### Fixed parameters for SBS processing

- `CONFIG_FILE_PATH`: Path to a Brieflow config file used during processing. Absolute or relative to where workflows are run from.

In [None]:
# Path to the Brieflow config file; absolute, or relative to where workflows are run from
CONFIG_FILE_PATH = "config/config.yml"

## Imports

In [None]:
from pathlib import Path

import yaml
from tifffile import imread
import pandas as pd
from snakemake.io import expand
from microfilm.microplot import Microimage
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from lib.shared.configuration_utils import (
    CONFIG_FILE_HEADER,
    create_micropanel,
    random_cmap,
    image_segmentation_annotations,
    convert_tuples_to_lists,
)
from lib.shared.file_utils import get_filename
from lib.sbs.align_cycles import align_cycles, visualize_sbs_alignment
from lib.shared.log_filter import log_filter
from lib.sbs.compute_standard_deviation import compute_standard_deviation
from lib.sbs.max_filter import max_filter
from lib.sbs.find_peaks import find_peaks, find_peaks_spotiflow, plot_channels_with_peaks
from lib.shared.illumination_correction import apply_ic_field, combine_ic_images
from lib.shared.segment_cellpose import prepare_cellpose
from lib.cluster.scrape_benchmarks import get_uniprot_data
from lib.sbs.standardize_barcode_design import standardize_barcode_design, get_barcode_list
from lib.sbs.extract_bases import extract_bases
from lib.sbs.call_reads import call_reads, plot_normalization_comparison
from lib.sbs.call_cells import call_cells
from lib.shared.extract_phenotype_minimal import extract_phenotype_minimal
from lib.sbs.eval_mapping import (
    plot_mapping_vs_threshold,
    plot_cell_mapping_heatmap,
    plot_cell_metric_histogram,
    plot_gene_symbol_histogram,
    plot_barcode_prefix_matching,
)

## <font color='red'>SET PARAMETERS</font>

### Parameters for testing SBS processing

- `TEST_PLATE`, `TEST_WELL`, `TEST_TILE`: Plate/well/tile combination used for configuring parameters in this notebook.

### Channels

- `CHANNEL_NAMES`: A list of ordered names for each channel in your SBS image. 
- `CHANNEL_CMAPS`: A list of color maps to use when showing channel microimages. These need to be a Matplotlib or microfilm colormap. We recommend using: `["pure_red", "pure_green", "pure_blue", "pure_cyan", "pure_magenta", "pure_yellow"]`.

In [None]:
# Plate/well/tile combination used to configure parameters in this notebook
TEST_PLATE = None
TEST_WELL = None
TEST_TILE = None
# Wildcard mapping for the test well and tile
WILDCARDS = {"well": TEST_WELL, "tile": TEST_TILE}

# Ordered channel names (DAPI first) and the colormap used to display each one
CHANNEL_NAMES = None
CHANNEL_CMAPS = None

# Split the channels into sequencing bases and everything else (e.g. DAPI)
BASES = [name for name in CHANNEL_NAMES if name in ("G", "T", "A", "C")]
EXTRA_CHANNELS = [name for name in CHANNEL_NAMES if name not in ("G", "T", "A", "C")]

## <font color='red'>SET PARAMETERS</font>

  ### Alignment
  - `ALIGNMENT_METHOD`: Optional. Method for aligning SBS images between cycles. If not specified, the method will be automatically selected based on available channels, but can be overridden by the user:
      - `DAPI`: the DAPI channel is used for alignment between cycles. Automatically selected if **DAPI is present in each** round of SBS imaging.
      - `sbs_mean`: the mean intensity of base channels is used for alignment between cycles. Automatically selected if **DAPI is not present in each** round of SBS imaging.
  - `UPSAMPLE_FACTOR`: Factor used for subpixel alignment. Defaults to `2`
      - Higher values provide more precise alignment but increase processing time.
      - Set to `1` to disable subpixel alignment for faster processing.
  - `SKIP_CYCLES`: Optional. List of cycle numbers (as they appear in the `cycle` column of the SBS samples file) to skip during alignment. Defaults to `None`
      - Use this to exclude problematic cycles that would interfere with alignment
      - Example: `[1]` to skip the first cycle, `[1, 6]` to skip cycles 1 and 6
      - Skipped cycles are completely removed from processing and will not appear in final results
  - `MANUAL_BACKGROUND_CYCLE`: Optional. Specific cycle to use as the source for segmentation background channels. Defaults to `None`
      - Use this when you have segmentation channels (e.g., DAPI, Vimentin) in a specific cycle that you want propagated to all cycles
      - If not specified, automatically selects the cycle with the most extra (non-GTAC) channels
      - Example: `3` to use cycle 3 as the source for background channels
      - Note: This refers to the original cycle number before any cycles are skipped
      - **Only works with simple channel configurations** (see note below)
  - `MANUAL_CHANNEL_MAPPING`: Optional. Explicit specification of which channels are present in each cycle. Defaults to `None`
      - Required when channel structure varies across cycles in non-standard ways
      - Provide a list of channel name lists, one per cycle, matching the order channels appear in each cycle's image data
      - The function will map channels by name to create the final output matching `CHANNEL_NAMES`

  **Note on Channel Configuration:**

  SBS images are commonly generated in two ways:
  1. **Same channels across all cycles** - All cycles have identical channel structure (e.g., all have DAPI, Vimentin, G, T, A, C)
  2. **Base channels + one background cycle** - Most cycles have only base channels (G, T, A, C), and one cycle has additional channels for segmentation (e.g., DAPI, Vimentin) at the beginning or end of the channel array

  If your imaging setup deviates from these approaches (e.g., some cycles missing channels that are in the middle of your target channel order), you must use `MANUAL_CHANNEL_MAPPING` to explicitly specify which channels are
  present in each cycle.

  **Example:** If `CHANNEL_NAMES = ["DAPI", "Vimentin", "G", "T", "A", "C"]` but only cycle 3 has Vimentin:

  ```python
  MANUAL_CHANNEL_MAPPING = [
      ["DAPI", "G", "T", "A", "C"],              # Cycle 1: no Vimentin
      ["DAPI", "G", "T", "A", "C"],              # Cycle 2: no Vimentin
      ["DAPI", "Vimentin", "G", "T", "A", "C"],  # Cycle 3: has Vimentin
      ["DAPI", "G", "T", "A", "C"],              # Cycle 4: no Vimentin
      # ... repeat for all cycles
  ]
  MANUAL_BACKGROUND_CYCLE = 3  # Source for Vimentin propagation
  ```

  The function will copy Vimentin from cycle 3 to all other cycles, creating output with shape (n_cycles, 6, height, width) where all cycles have the full `CHANNEL_NAMES` channel order.

In [None]:
# Alignment method: "DAPI", "sbs_mean", or None to select automatically from the available channels
ALIGNMENT_METHOD = None
# Subpixel alignment factor; 1 disables subpixel alignment
UPSAMPLE_FACTOR = 2
# Cycles to exclude from alignment and all downstream processing, or None
SKIP_CYCLES = None
# Cycle supplying the segmentation background channels, or None to select automatically
MANUAL_BACKGROUND_CYCLE = None
# One list of channel names per cycle, or None when the channel structure is standard
MANUAL_CHANNEL_MAPPING = None

In [None]:
# Load config file
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

# Get path to the SBS samples table
SBS_SAMPLES_FP = Path(config["preprocess"]["sbs_samples_fp"])
# Load the sample TSV file and derive the sorted list of imaged cycles
sbs_samples = pd.read_csv(SBS_SAMPLES_FP, sep="\t")
SBS_CYCLES = sorted(sbs_samples["cycle"].unique())
# Translate the user-provided cycle numbers into positions within SBS_CYCLES,
# which is the order the test images are loaded in below
SKIP_CYCLES_INDICES = (
    [SBS_CYCLES.index(c) for c in SKIP_CYCLES] if SKIP_CYCLES is not None else None
)
MANUAL_BACKGROUND_CYCLE_INDEX = (
    SBS_CYCLES.index(MANUAL_BACKGROUND_CYCLE)
    if MANUAL_BACKGROUND_CYCLE is not None
    else None
)

# Load test image data (one image per cycle for the test plate/well/tile)
print("Loading test images...")
ROOT_FP = Path(config["all"]["root_fp"])
PREPROCESS_FP = ROOT_FP / "preprocess"
sbs_test_image_paths = expand(
    PREPROCESS_FP
    / "images"
    / "sbs"
    / get_filename(
        {"plate": TEST_PLATE, "well": TEST_WELL, "tile": TEST_TILE, "cycle": "{cycle}"},
        "image",
        "tiff",
    ),
    cycle=SBS_CYCLES,
)
sbs_test_images = [imread(file_path) for file_path in sbs_test_image_paths]

# Align cycles
print("Aligning test images...")
aligned = align_cycles(
    sbs_test_images,
    channel_order=CHANNEL_NAMES,
    method=ALIGNMENT_METHOD,
    # Use the configured factor so this test reflects the UPSAMPLE_FACTOR setting
    upsample_factor=UPSAMPLE_FACTOR,
    skip_cycles=SKIP_CYCLES_INDICES,
    manual_background_cycle=MANUAL_BACKGROUND_CYCLE_INDEX,
    manual_channel_mapping=MANUAL_CHANNEL_MAPPING,
)

# Create and display micropanel of aligned images
print("Example aligned image for first cycle:")
aligned_microimages = [
    Microimage(
        aligned[0, i, :, :], channel_names=CHANNEL_NAMES[i], cmaps=CHANNEL_CMAPS[i]
    )
    for i in range(aligned.shape[1])
]
aligned_panel = create_micropanel(aligned_microimages, add_channel_label=True)
plt.show()

# NOTE: You can also loop through all your cycles to display micropanels to be sure of alignment by uncommenting the following lines:
# for cycle_idx in range(aligned.shape[0]):  # Adjust this range if you have a different number of cycles
#     print(f"Example aligned image for cycle {cycle_idx + 1}:")
#     aligned_microimages = [
#         Microimage(
#             aligned[cycle_idx, i, :, :], channel_names=CHANNEL_NAMES[i], cmaps=CHANNEL_CMAPS[i]
#         )
#         for i in range(aligned.shape[1])
#     ]
#     aligned_panel = create_micropanel(aligned_microimages, add_channel_label=True)
#     plt.show()

### Visualize Cycle-to-Cycle Alignment (Optional)

Verify base channels are properly aligned across cycles. Shows 3 locations (corner, center, random) with DAPI reference (grayscale) and base channels from different cycles (RGB overlay). Color fringing indicates misalignment.

- `DAPI_REFERENCE_CYCLE`: Cycle index for DAPI anatomical reference (shown as grayscale)
- `VIZ_CHANNELS`: List of 3 `(cycle_idx, channel_name)` tuples for RGB overlay (e.g., `[(0, "G"), (5, "T"), (10, "A")]`)

In [None]:
# Cycle whose DAPI channel is drawn as the grayscale background
DAPI_REFERENCE_CYCLE = 0

# Three (cycle_index, channel_name) pairs rendered as the RGB overlay
VIZ_CHANNELS = None  # Example: [(0, "G"), (5, "T"), (10, "A")]

if VIZ_CHANNELS is None:
    # Nothing configured: the visualization is optional, so just report and move on
    print("Skipping visualization (VIZ_CHANNELS not set)")
else:
    print("Visualizing alignment...")
    alignment_fig = visualize_sbs_alignment(
        aligned,
        CHANNEL_NAMES,
        DAPI_REFERENCE_CYCLE,
        VIZ_CHANNELS,
        crop_size=300,
    )
    plt.show()

## <font color='red'>SET PARAMETERS</font>

### Spot detection

- `MAX_FILTER_WIDTH`: Width parameter used for determining the neighborhood size for finding local maxima. No default value, but `3` is suggested
- `SPOT_DETECTION_METHOD`: Methodology for calling spots. Set to one of the following (lowercase) values:
    - `"standard"`: Standard method for calling spots, involving Laplacian of Gaussian correction, and computation of spots across cycles. Spots are identified based on signal intensity and consistency across cycles.
        - `PEAK_WIDTH`: Width parameter used for peak detection, sets neighborhood size for finding local maxima of base channels standard deviation. Defaults to `5`
    - `"spotiflow"`: Deep learning based methodology for calling spots. Spots are called independently on all 4 bases of a selected cycle, and then coalesced to ensure minimum distance between spots. If this method is selected, the following parameters are required:
        - `SPOTIFLOW_CYCLE_INDEX`: Cycle to use for spot detection
        - `SPOTIFLOW_MODEL`: Model to use for spot detection (e.g., "general", "hybiss")
        - `SPOTIFLOW_THRESHOLD`: Probability threshold for confidence in spot detection
        - `SPOTIFLOW_MIN_DISTANCE`: Minimum distance (in pixels) required between detected spots


In [None]:
# Neighborhood width for the max filter and the spot-calling method to test
MAX_FILTER_WIDTH = 3
SPOT_DETECTION_METHOD = "standard"

if SPOT_DETECTION_METHOD == "standard":
    # Neighborhood size used when finding peaks of the standard deviation image
    PEAK_WIDTH = 5
elif SPOT_DETECTION_METHOD == "spotiflow":
    # Cycle, model, probability threshold and minimum spot spacing for Spotiflow
    SPOTIFLOW_CYCLE_INDEX = 0
    SPOTIFLOW_MODEL = "general"
    SPOTIFLOW_THRESHOLD = 0.3
    SPOTIFLOW_MIN_DISTANCE = 1

# Positions of the non-base channels within CHANNEL_NAMES
EXTRA_CHANNEL_INDICES = [CHANNEL_NAMES.index(name) for name in EXTRA_CHANNELS]

In [None]:
print("Detecting candidate reads...")

print("Applying Laplacian-of-Gaussian (LoG) filter...")
loged = log_filter(aligned, skip_index=EXTRA_CHANNEL_INDICES)

print("Applying max filter...")
maxed = max_filter(loged, width=MAX_FILTER_WIDTH, remove_index=EXTRA_CHANNEL_INDICES)

if SPOT_DETECTION_METHOD == "standard":
    print("Computing standard deviation over cycles...")
    standard_deviation = compute_standard_deviation(loged, remove_index=EXTRA_CHANNEL_INDICES)

    print("Finding peaks using standard method...")
    peaks = find_peaks(standard_deviation, width=PEAK_WIDTH)

elif SPOT_DETECTION_METHOD == "spotiflow":
    print(f"Finding peaks using Spotiflow (model: {SPOTIFLOW_MODEL})...")
    peaks, _ = find_peaks_spotiflow(
        aligned_images=aligned,
        cycle_idx=SPOTIFLOW_CYCLE_INDEX,
        model=SPOTIFLOW_MODEL,
        prob_thresh=SPOTIFLOW_THRESHOLD,
        min_distance=SPOTIFLOW_MIN_DISTANCE,
        verbose=True,
    )

else:
    # Fail with a clear message instead of a NameError on `peaks` further down
    raise ValueError(
        f"Unknown SPOT_DETECTION_METHOD '{SPOT_DETECTION_METHOD}'. "
        "Expected 'standard' or 'spotiflow'."
    )

# Create and display micropanel of max filtered data for the first cycle (index 0),
# matching the cycle used in the peak plot below
print("Example max filtered image for first cycle:")
maxed_microimages = [
    Microimage(
        maxed[0, i, :, :],
        channel_names=BASES[i],
        # Look the colormap up by channel name so it stays correct when there is
        # more than one non-base channel (or none) in CHANNEL_NAMES
        cmaps=CHANNEL_CMAPS[CHANNEL_NAMES.index(BASES[i])],
    )
    for i in range(maxed.shape[1])
]
maxed_panel = create_micropanel(maxed_microimages, add_channel_label=True)
plt.show()

# Plot spots on base channels
fig, axes = plot_channels_with_peaks(
    maxed,
    peaks,
    BASES,
    cycle_number=0,
    threshold_peaks=200 if SPOT_DETECTION_METHOD == "standard" else None,
    peak_colors=['orange'],
    peak_labels=['Detected Peaks'],
)
plt.show()

## <font color='red'>SET PARAMETERS</font>

### Illumination Correction and Segmentation

- `DAPI_CYCLE`: Cycle number where DAPI was imaged. Used for illumination correction of the DAPI channel.
- `CYTO_CYCLE`: Cycle number where the cytoplasmic/background channel was imaged. Used for illumination correction of the cytoplasmic channel.
- `CYTO_CHANNEL`: The name of the cytoplasmic channel used for cell boundary detection. 
  - When `DAPI_CYCLE == CYTO_CYCLE`: Can be any channel imaged in that cycle (e.g., "Vimentin", "GFP", "Phalloidin")
  - When `DAPI_CYCLE != CYTO_CYCLE`: Should be a base channel ("G", "T", "A", or "C") since DAPI was imaged separately

**Note**: Only set `DAPI_CYCLE != CYTO_CYCLE` when DAPI was imaged in a separate cycle from your cytoplasmic signal. This typically occurs when using a base channel (G, T, A, C) for cytoplasm detection because no cellular stain was available in the DAPI cycle.

In [None]:
# Define cycles to use for segmentation
DAPI_CYCLE = None
CYTO_CYCLE = None
# Define cytoplasmic channel
CYTO_CHANNEL = None

# Derive DAPI and CYTO channel indexes
DAPI_INDEX = CHANNEL_NAMES.index("DAPI")
CYTO_INDEX = CHANNEL_NAMES.index(CYTO_CHANNEL)

# A skipped cycle cannot be used for segmentation
if CYTO_CYCLE in (SKIP_CYCLES or []):
    raise ValueError(
        f"CYTO_CYCLE ({CYTO_CYCLE}) is listed in SKIP_CYCLES ({SKIP_CYCLES}); "
        "choose a cycle that is not skipped."
    )

# Position of CYTO_CYCLE once skipped cycles are dropped: its position in SBS_CYCLES
# minus the number of skipped cycles that come before it. SKIP_CYCLES holds cycle
# numbers, so each is converted to its position before being compared.
_cyto_cycle_position = SBS_CYCLES.index(CYTO_CYCLE)
CYTO_CYCLE_INDEX = _cyto_cycle_position - sum(
    SBS_CYCLES.index(skip) < _cyto_cycle_position for skip in (SKIP_CYCLES or [])
)

# Validate DAPI and CYTO cycles and channels
if DAPI_CYCLE != CYTO_CYCLE and CYTO_CHANNEL not in BASES:
    raise ValueError(
        f"When DAPI_CYCLE ({DAPI_CYCLE}) != CYTO_CYCLE ({CYTO_CYCLE}), "
        f"CYTO_CHANNEL should be a base channel {BASES}, but got '{CYTO_CHANNEL}'. "
        f"If using a cellular stain, set DAPI_CYCLE = CYTO_CYCLE."
    )

In [None]:
# Segmentation runs on the aligned cycle that carries the cytoplasmic channel
aligned_image_data_segmentation_cycle = aligned[CYTO_CYCLE_INDEX]

print("Applying illumination correction to segmentation cycle image...")


def _read_sbs_ic_field(cycle):
    """Read the SBS illumination correction field of the test plate/well for `cycle`."""
    return imread(
        PREPROCESS_FP
        / "ic_fields"
        / "sbs"
        / get_filename(
            {"plate": TEST_PLATE, "well": TEST_WELL, "cycle": cycle}, "ic_field", "tiff"
        )
    )


if DAPI_CYCLE == CYTO_CYCLE:
    # DAPI and cytoplasmic channel share a cycle: a single IC field is enough
    ic_field = _read_sbs_ic_field(DAPI_CYCLE)
else:
    # DAPI and cytoplasmic channel come from different cycles: merge both IC fields
    ic_field_dapi = _read_sbs_ic_field(DAPI_CYCLE)
    ic_field_cyto = _read_sbs_ic_field(CYTO_CYCLE)
    # NOTE(review): the original flags the index arguments to combine_ic_images as unsettled — confirm
    ic_field = combine_ic_images(
        [ic_field_dapi, ic_field_cyto], [EXTRA_CHANNEL_INDICES, None]
    )

# Apply illumination correction field
corrected_image = apply_ic_field(
    aligned_image_data_segmentation_cycle,
    correction=ic_field,
)

# Prepare corrected image for CellPose segmentation
# NOTE: the same preparation also happens inside `segment_cellpose` further below;
# the prepared image is only used here for previewing
print("Preparing IC segmentation image for CellPose...")
cellpose_rgb = prepare_cellpose(corrected_image, DAPI_INDEX, CYTO_INDEX)

# Show the two planes of the prepared image that are labelled DAPI and cytoplasm
print("Pre-segmentation images:")
pre_seg_microimages = [
    Microimage(cellpose_rgb[plane], channel_names=label)
    for plane, label in ((2, "Dapi"), (1, "Cyto"))
]
pre_seg_panel = create_micropanel(pre_seg_microimages, add_channel_label=True)
plt.show()

## <font color='red'>SET PARAMETERS</font>

### Segmentation

**IMPORTANT: GPU Recommendation for CPSAM**
If testing the CPSAM model (`cellpose_model="cpsam"`), we strongly recommend:
- Using a GPU-enabled machine (`GPU=True`)
- Allocating sufficient time (segmentation can take 30+ minutes per tile)
- Consider running this notebook in a GPU-enabled environment or testing on a smaller region

#### Select Segmentation Method
- `SEGMENTATION_METHOD`: Choose from "cellpose", "stardist", or "watershed" for cell segmentation.

#### Common Parameters
- `GPU`: Set to True to use GPU acceleration (if available).
- `RECONCILE`: Method for reconciling nuclei and cell masks (typically "contained_in_cells", which allows more than one nucleus per cell and is useful for cells that are dividing).
- `SEGMENT_CELLS`: Whether to segment cells, or only segment nuclei. If spots are contained in nuclei, there is no need to segment cell bodies. This may speed up analysis.

#### Cellpose Parameters (if using "cellpose")
- `CELLPOSE_MODEL`: CellPose model to use. Options: "cyto3" (default), "cyto2", "cyto", or "cpsam" (requires Cellpose 4.x).
- `CELL_FLOW_THRESHOLD` & `NUCLEI_FLOW_THRESHOLD`: Flow threshold for Cellpose segmentation. Default is 0.4.
- `CELL_CELLPROB_THRESHOLD` & `NUCLEI_CELLPROB_THRESHOLD`: Cell probability threshold for Cellpose. Default is 0.
- `HELPER_INDEX`: (Optional) Index of additional channel to help with CPSAM segmentation. Only used with `cellpose_model="cpsam"`. Default is None.
- Note: For Cellpose 3.x models (cyto3, cyto2), nuclei and cell diameters will be estimated automatically. For CPSAM (Cellpose 4.x), diameters can be left as None and will be estimated from initial segmentation results.

#### StarDist Parameters (if using "stardist")
- `STARDIST_MODEL`: StarDist model type. Default is "2D_versatile_fluo".
- `CELL_PROB_THRESHOLD` & `NUCLEI_PROB_THRESHOLD`: Probability threshold for segmentation. Default is 0.479071.
- `CELL_NMS_THRESHOLD` & `NUCLEI_NMS_THRESHOLD`: Non-maximum suppression threshold. Default is 0.3.

#### Watershed Parameters (if using "watershed")
- `THRESHOLD_DAPI`: Threshold for nuclei segmentation. 
- `THRESHOLD_CELL`: Threshold for cell boundary segmentation.
- `NUCLEUS_AREA`: Range for filtering nuclei by area as a tuple (min, max).

In [None]:
# Segmentation backend to test: "cellpose", "stardist", or "watershed"
SEGMENTATION_METHOD = "cellpose"

# Settings shared by the segmentation backends
GPU = False
RECONCILE = "contained_in_cells"
SEGMENT_CELLS = True

if SEGMENTATION_METHOD == "cellpose":
    # CellPose settings
    CELLPOSE_MODEL = "cyto3"
    NUCLEI_FLOW_THRESHOLD = 0.4
    NUCLEI_CELLPROB_THRESHOLD = 0.0
    CELL_FLOW_THRESHOLD = 1
    CELL_CELLPROB_THRESHOLD = 0
    HELPER_INDEX = None  # Optional: channel index to help with CPSAM segmentation

    if CELLPOSE_MODEL == "cpsam":
        # CPSAM: diameters stay unset here and are derived from the masks after segmentation
        print("CPSAM model selected. Initial diameters set to None.")
        print("Note: Diameters will be estimated automatically from segmentation results in the next cell.")
        NUCLEI_DIAMETER = None
        CELL_DIAMETER = None
    else:
        # Every other CellPose model: estimate diameters up front
        from lib.shared.segment_cellpose import estimate_diameters

        print("Estimating optimal cell and nuclei diameters...")
        NUCLEI_DIAMETER, CELL_DIAMETER = estimate_diameters(
            corrected_image,
            dapi_index=DAPI_INDEX,
            cyto_index=CYTO_INDEX,
            cellpose_model=CELLPOSE_MODEL,
        )

elif SEGMENTATION_METHOD == "stardist":
    # StarDist settings
    STARDIST_MODEL = "2D_versatile_fluo"
    NUCLEI_PROB_THRESHOLD = 0.479071
    NUCLEI_NMS_THRESHOLD = 0.3
    CELL_PROB_THRESHOLD = 0.479071
    CELL_NMS_THRESHOLD = 0.3

elif SEGMENTATION_METHOD == "watershed":
    # Watershed settings; NUCLEUS_AREA is the (min, max) area filter for nuclei
    THRESHOLD_DAPI = 4260
    THRESHOLD_CELL = 1300
    NUCLEUS_AREA = (45, 450)

In [None]:
print(f"Segmenting image with {SEGMENTATION_METHOD}...")

if SEGMENTATION_METHOD == "cellpose":
    from lib.shared.segment_cellpose import segment_cellpose
    result = segment_cellpose(
        corrected_image,
        dapi_index=DAPI_INDEX,
        cyto_index=CYTO_INDEX,
        nuclei_diameter=NUCLEI_DIAMETER,
        cell_diameter=CELL_DIAMETER,
        cellpose_kwargs=dict(
            nuclei_flow_threshold=NUCLEI_FLOW_THRESHOLD,
            nuclei_cellprob_threshold=NUCLEI_CELLPROB_THRESHOLD,
            cell_flow_threshold=CELL_FLOW_THRESHOLD,
            cell_cellprob_threshold=CELL_CELLPROB_THRESHOLD,
        ),
        cellpose_model=CELLPOSE_MODEL,
        helper_index=HELPER_INDEX,
        gpu=GPU,
        reconcile=RECONCILE,
        cells=SEGMENT_CELLS,
    )

elif SEGMENTATION_METHOD == "stardist":
    from lib.shared.segment_stardist import segment_stardist
    result = segment_stardist(
        corrected_image,
        dapi_index=DAPI_INDEX,
        cyto_index=CYTO_INDEX,
        model_type=STARDIST_MODEL,
        stardist_kwargs=dict(
            nuclei_prob_threshold=NUCLEI_PROB_THRESHOLD,
            nuclei_nms_threshold=NUCLEI_NMS_THRESHOLD,
            cell_prob_threshold=CELL_PROB_THRESHOLD,
            cell_nms_threshold=CELL_NMS_THRESHOLD,
        ),
        gpu=GPU,
        reconcile=RECONCILE,
        cells=SEGMENT_CELLS,
    )

elif SEGMENTATION_METHOD == "watershed":
    from lib.shared.segment_watershed import segment_watershed
    result = segment_watershed(
        corrected_image,
        nuclei_threshold=THRESHOLD_DAPI,
        nuclei_area_min=NUCLEUS_AREA[0],
        nuclei_area_max=NUCLEUS_AREA[1],
        cell_threshold=THRESHOLD_CELL,
        cells=SEGMENT_CELLS,
    )

else:
    # Fail with a clear message instead of a NameError on `result` below
    raise ValueError(
        f"Unknown SEGMENTATION_METHOD '{SEGMENTATION_METHOD}'. "
        "Expected 'cellpose', 'stardist', or 'watershed'."
    )

# Handle unpacking based on SEGMENT_CELLS
if SEGMENT_CELLS:
    nuclei, cells = result
else:
    nuclei = result
    cells = nuclei  # Use nuclei as cells for downstream code compatibility

# Label used wherever the second mask is shown (`cells` is the nuclei mask when
# cell segmentation is disabled)
mask_label = "Cells" if SEGMENT_CELLS else "Nuclei"

# Create and display micropanel of nuclei segmentation
print("Example microplots for DAPI channel and nuclei segmentation:")
nuclei_cmap = random_cmap(num_colors=len(np.unique(nuclei)))
nuclei_seg_microimages = [
    Microimage(cellpose_rgb[2], channel_names="Dapi"),
    Microimage(nuclei, cmaps=nuclei_cmap, channel_names="Nuclei"),
]
nuclei_seg_panel = create_micropanel(nuclei_seg_microimages, add_channel_label=True)
plt.show()

# Create and display micropanel of segmented cells
print(f"Example microplots for merged channels and {'cells' if SEGMENT_CELLS else 'nuclei'} segmentation:")
cells_cmap = random_cmap(num_colors=len(np.unique(cells)))
cells_seg_microimages = [
    Microimage(cellpose_rgb, channel_names="Merged"),
    Microimage(cells, cmaps=cells_cmap, channel_names=mask_label),
]
cells_seg_panel = create_micropanel(cells_seg_microimages, add_channel_label=True)
plt.show()

# Create and display micropanel of annotated data
print("Example microplot for sequencing data annotated with segmentation:")
annotated_data = image_segmentation_annotations(cellpose_rgb[1:], nuclei, cells)
annotated_microimage = [
    Microimage(
        annotated_data,
        channel_names="Merged",
        cmaps=["pure_blue", "pure_red", "pure_green"],
    )
]
annotated_panel = create_micropanel(
    annotated_microimage, num_cols=1, figscaling=10, add_channel_label=False
)
plt.show()

# CPSAM runs without preset diameters, so derive them from the resulting masks
if SEGMENTATION_METHOD == "cellpose" and CELLPOSE_MODEL == "cpsam":
    from skimage.measure import regionprops

    # Calculate nuclei diameters
    nuclei_props = regionprops(nuclei)
    nuclei_diameters = [prop.equivalent_diameter for prop in nuclei_props]
    estimated_nuclei_diameter = np.mean(nuclei_diameters)
    print(f"Nuclei - Average diameter: {estimated_nuclei_diameter:.2f} pixels")

    # Calculate cell diameters
    cells_props = regionprops(cells)
    cells_diameters = [prop.equivalent_diameter for prop in cells_props]
    estimated_cell_diameter = np.mean(cells_diameters)
    print(f"Cells - Average diameter: {estimated_cell_diameter:.2f} pixels")

    # Update the diameter variables for config
    NUCLEI_DIAMETER = estimated_nuclei_diameter
    CELL_DIAMETER = estimated_cell_diameter
    print(f"\nUpdated NUCLEI_DIAMETER to {NUCLEI_DIAMETER:.2f} pixels")
    print(f"Updated CELL_DIAMETER to {CELL_DIAMETER:.2f} pixels")

**Note:** If the masks capture too little or too much area, adjust the parameters above and rerun the segmentation test. For Cellpose, the nuclei and cell diameters are estimated automatically, but you can manually set `NUCLEI_DIAMETER` and `CELL_DIAMETER` and rerun the above cells as many times as needed.

## <font color='red'>SET PARAMETERS</font>

### Barcode design table standardization

Raw barcode design tables from different sources often have inconsistent formatting, column names, and gene annotations that need to be cleaned and validated before analysis. This standardization step transforms your raw design file into a consistent format with validated gene symbols, standardized column names, and proper barcode prefixes for read mapping.

**Barcode Type Selection:**
- `BARCODE_TYPE`: Choose "simple" (single-barcode protocol) or "multi" (multi-barcode with MAP/RECOMB regions)
- **Default**: "simple" for standard protocols
- See Cell 32 for detailed explanation of barcode types and when to use each mode

**Essential Parameters:**
- `DF_DESIGN_FP`: File path to your raw guide RNA design file (TSV format)
- `DF_BARCODE_LIBRARY_FP`: File path where the cleaned, standardized barcode library will be saved
- `UNIPROT_DATA_FP`: File path for temporary UniProt annotation data (automatically generated and deleted)
- `GENE_SYMBOL_COL`: Column name containing gene symbols (e.g., "gene_symbol", "target_gene"). Set to `None` if unavailable
- `GENE_ID_COL`: Column name containing gene IDs (e.g., "gene_id", "ensembl_id"). Set to `None` if not needed

**Simple Mode Parameters:**
- `BARCODE_COL`: Column containing full barcode sequences (e.g., "sgRNA", "guide_sequence")
- `PREFIX_LENGTH`: Total barcode length BEFORE skipping cycles
  - Final length = PREFIX_LENGTH - number of skipped cycles
  - Example: 13 bases → skip [2,3,4] → 10 final bases
- `SKIP_CYCLES_MAP`: 1-based cycle positions to skip (e.g., `[2, 3, 4]`)
  - **IMPORTANT**: Must match `SKIP_CYCLES` from Cell 8 (imaging alignment)
  - Set to `None` if no cycles were skipped

**Multi Mode Parameters:**
- `PREFIX_MAP`: Column with MAP region barcode sequences (e.g., "iBAR2")
- `PREFIX_RECOMB`: Column with RECOMB region barcode sequences (optional)
- `MAP_PREFIX_LENGTH`: Total bases BEFORE skipping for MAP region
  - Final length = MAP_PREFIX_LENGTH - len(SKIP_CYCLES_MAP)
- `RECOMB_PREFIX_LENGTH`: Total bases BEFORE skipping for RECOMB region
  - Final length = RECOMB_PREFIX_LENGTH - len(SKIP_CYCLES_RECOMB)
- `SKIP_CYCLES_MAP`: 1-based positions to skip in MAP region (optional)
- `SKIP_CYCLES_RECOMB`: 1-based positions to skip in RECOMB region (optional)

**Non-Targeting Control Parameters:**
- `NONTARGETING_FORMAT`: Format string for standardized non-targeting names (default: "nontargeting_{prefix}")
  - Use `{prefix}` for barcode prefix, `{original}` for original name
- `NONTARGETING_PATTERNS`: List of patterns to identify non-targeting controls (default: ["nontargeting", "sg_nt", "non-targeting"])

**Note:** For complex scenarios (custom prefix generation), you can use custom prefix functions (`prefix_map_func`, `prefix_recomb_func`). See function documentation for details.

In [None]:
# Define barcode type: "simple" (single barcode column) or "multi" (MAP/RECOMB regions)
BARCODE_TYPE = "simple"

# Essential parameters (common to both modes)
DF_DESIGN_FP = None
DF_BARCODE_LIBRARY_FP = "config/barcode_library.tsv"
UNIPROT_DATA_FP = "config/uniprot_data.tsv"
GENE_SYMBOL_COL = None
GENE_ID_COL = None

# Non-targeting control parameters
NONTARGETING_FORMAT = "nontargeting_{prefix}"  # Use {prefix} for barcode prefix, {original} for original name
NONTARGETING_PATTERNS = ["nontargeting", "sg_nt", "non-targeting"]  # Patterns to identify controls

# Mode-specific barcode parameters
# Each branch also sets the other mode's parameters to None so every name is defined
if BARCODE_TYPE == "simple":
    # Simple mode: single barcode column and prefix length
    BARCODE_COL = None  # Column with full barcode sequences
    
    # IMPORTANT: PREFIX_LENGTH should be the FULL barcode length in your design library
    # If you skipped cycles during imaging (Cell 8: SKIP_CYCLES), use SKIP_CYCLES_MAP below
    PREFIX_LENGTH = None  # Full barcode length (e.g., 12 for 12-base barcodes)
    
    # If you skipped cycles during imaging, set this to match SKIP_CYCLES from Cell 8
    # Example: SKIP_CYCLES = [1, 6] in Cell 8 → SKIP_CYCLES_MAP = [1, 6] here
    SKIP_CYCLES_MAP = None  # Set to list of 1-based cycle numbers to skip, or None if no cycles skipped
    
    # Multi-mode parameters set to None
    PREFIX_MAP = None
    PREFIX_RECOMB = None
    MAP_PREFIX_LENGTH = None
    RECOMB_PREFIX_LENGTH = None
    SKIP_CYCLES_RECOMB = None
    SEQUENCING_ORDER = None

elif BARCODE_TYPE == "multi":
    # Multi mode: separate columns for MAP and RECOMB regions
    PREFIX_MAP = None  # Column with MAP region barcode sequences (e.g., "iBAR2")
    PREFIX_RECOMB = None  # Column with RECOMB region barcode sequences (optional)
    MAP_PREFIX_LENGTH = None  # Length of MAP prefix to extract (e.g., 6 for cycles 0-5)
    RECOMB_PREFIX_LENGTH = None  # Length of RECOMB prefix to extract (e.g., 6 for cycles 6-11)
    SKIP_CYCLES_MAP = None  # Optional: list of 1-based cycles to skip in MAP region
    SKIP_CYCLES_RECOMB = None  # Optional: list of 1-based cycles to skip in RECOMB region
    SEQUENCING_ORDER = None  # Sequencing order for barcode extraction (e.g., ["G", "T", "A", "C"])
    
    # Simple mode parameters set to None
    BARCODE_COL = None
    PREFIX_LENGTH = None

In [None]:
# Reject an unsupported barcode type up front, before any data is fetched or written
if BARCODE_TYPE not in ("simple", "multi"):
    raise ValueError(
        f"Unknown BARCODE_TYPE '{BARCODE_TYPE}'. Expected 'simple' or 'multi'."
    )

# Get uniprot data
uniprot_data = get_uniprot_data()

try:
    # Save uniprot data temporarily; standardize_barcode_design receives it by path
    uniprot_data.to_csv(UNIPROT_DATA_FP, sep="\t", index=False)
    uniprot_data = pd.read_csv(UNIPROT_DATA_FP, sep="\t")

    # Read design table
    print("Loading and standardizing barcode design table...")
    df_design = pd.read_csv(DF_DESIGN_FP, sep="\t")

    # Call standardize_barcode_design with mode-specific parameters
    if BARCODE_TYPE == "simple":
        # Simple mode: the barcode column and prefix length are passed through the
        # MAP-region parameters
        df_barcode_library = standardize_barcode_design(
            df_design,
            prefix_map=BARCODE_COL,  # In simple mode, this is the barcode column
            gene_symbol_col=GENE_SYMBOL_COL,
            gene_id_col=GENE_ID_COL,
            map_prefix_length=PREFIX_LENGTH,  # In simple mode, this is the prefix length
            skip_cycles_map=SKIP_CYCLES_MAP,  # Pass skip_cycles for simple mode
            uniprot_data_path=UNIPROT_DATA_FP,
            nontargeting_format=NONTARGETING_FORMAT,
            nontargeting_patterns=NONTARGETING_PATTERNS,
        )
    else:
        # Multi mode: use prefix_map, prefix_recomb, and region-specific lengths
        df_barcode_library = standardize_barcode_design(
            df_design,
            prefix_map=PREFIX_MAP,
            prefix_recomb=PREFIX_RECOMB,
            gene_symbol_col=GENE_SYMBOL_COL,
            gene_id_col=GENE_ID_COL,
            map_prefix_length=MAP_PREFIX_LENGTH,
            recomb_prefix_length=RECOMB_PREFIX_LENGTH,
            skip_cycles_map=SKIP_CYCLES_MAP,
            skip_cycles_recomb=SKIP_CYCLES_RECOMB,
            uniprot_data_path=UNIPROT_DATA_FP,
            nontargeting_format=NONTARGETING_FORMAT,
            nontargeting_patterns=NONTARGETING_PATTERNS,
        )
finally:
    # Delete the temporary uniprot data file even if loading or standardizing failed
    Path(UNIPROT_DATA_FP).unlink(missing_ok=True)

# Save standardized design table
df_barcode_library.to_csv(DF_BARCODE_LIBRARY_FP, sep="\t", index=False)
print(f"Standardized barcode design saved to: {DF_BARCODE_LIBRARY_FP}")
display(df_barcode_library)

# Extract barcodes (prefixes) for mapping - conditional based on barcode type
if BARCODE_TYPE == "multi":
    barcodes = get_barcode_list(df_barcode_library, sequencing_order=SEQUENCING_ORDER)
else:
    barcodes = get_barcode_list(df_barcode_library)

print(f"Extracted {len(barcodes)} barcode prefixes for read mapping")

## <font color='red'>SET PARAMETERS</font>

### Extract base intensity, call reads, assign to cells
- `THRESHOLD_READS`: Initial threshold for detecting sequencing reads, set to ~50 for preliminary analysis. This parameter will be optimized based on the mapping rate vs. peak threshold plot generated below. A higher threshold increases confidence in read calls but reduces the total number of detected reads.
- `CALL_READS_METHOD`: Method used to correct base intensities across channels. The `plot_normalization_comparison` function below helps you assess which method to use. Options are:
    - `"median"`: Median-based correction, performed independently for each tile. This is the default method.
    - `"percentile"`: Percentile-based correction, performed independently for each tile.

In [None]:
# Initial settings for calling reads; THRESHOLD_READS is revisited further down
# once the mapping-rate vs. peak-threshold plot has been inspected.
CALL_READS_METHOD = "median"  # passed to call_reads as `method`
THRESHOLD_READS = 50  # preliminary threshold passed to extract_bases

In [None]:
# First pass with the preliminary threshold: extract bases, then call reads.
# The mask argument is `cells` when SEGMENT_CELLS is truthy, otherwise `nuclei`.
df_bases = extract_bases(
    peaks,
    maxed,
    cells if SEGMENT_CELLS else nuclei,
    THRESHOLD_READS,
    wildcards=WILDCARDS,
    bases=BASES,
)
df_reads = call_reads(
    df_bases,
    peaks_data=peaks,
    method=CALL_READS_METHOD,
)

### Determine Optimal Read Threshold

**These plots show READ-LEVEL metrics** (not final cell-level mapping):
- **Blue line**: Fraction of reads matching expected barcodes
- **Orange solid**: Total reads with valid barcodes  
- **Orange dotted**: Unique cells with ≥1 valid barcode (not necessarily singlets)

Use these to set `THRESHOLD_READS` to maximize clean reads. Final cell-level QC comes later.

In [None]:
# Plot read-level mapping metrics for the reads called above (df_reads) against
# the expected barcodes, sweeping the "peak" threshold; used to pick THRESHOLD_READS.
print("Mapping rate vs. peak threshold for determining optimal peak cutoff:")
plot_mapping_vs_threshold(df_reads, barcodes, "peak")
plt.show()

**How to read these plots:**
- **Left**: All reads (including background noise)
- **Right**: Cell-associated reads only (cell > 0)

**Goal**: Find threshold where mapping rate plateaus (~70-80%) while retaining enough reads/cells for analysis.

**Note**: High read mapping ≠ high cell mapping. Many reads might cluster in few cells, or cells might have mixed barcodes. See cell-level QC below for final mapping quality.

In [None]:
# Placeholder: replace None with the peak threshold chosen from the plots above.
# The value is passed to extract_bases in the next cell and written to the config
# file as "threshold_peaks".
THRESHOLD_READS = None

In [None]:
# Second pass: extract bases again using the THRESHOLD_READS chosen above,
# then inspect intensities per cycle and per channel.
print("Extracted bases:")
df_bases = extract_bases(
    peaks,
    maxed,
    cells if SEGMENT_CELLS else nuclei,
    THRESHOLD_READS,
    wildcards=WILDCARDS,
    bases=BASES,
)
display(df_bases)

# Point plot of intensity per cycle, one series per channel
print("Base intensity across cycles:")
ax = sns.pointplot(data=df_bases, x="cycle", y="intensity", hue="channel")
plt.show()

# Box plot of intensity per channel, outliers hidden
print("Intensity for each base:")
ax = sns.boxplot(
    data=df_bases,
    x="channel",
    y="intensity",
    hue="channel",
    showfliers=False,
)
plt.show()

# Side-by-side view of the available normalization methods
print("Different normalization methods")
plot_normalization_comparison(df_bases, base_order=BASES)

In [None]:
# Call reads from the re-extracted bases and check barcode prefix matching.
print("Called reads:")
# Pass the selected correction method explicitly so the reads evaluated below are
# produced with the same `call_reads_method` that is written to the config file
# (previously this call silently fell back to the function's default method).
df_reads = call_reads(df_bases, peaks_data=peaks, method=CALL_READS_METHOD)
display(df_reads)

print("Barcode prefix matching (actual vs random):")
# Both return values are discarded; only the rendered figure is needed here.
_, _ = plot_barcode_prefix_matching(df_reads, df_barcode_library)
plt.show()

## <font color='red'>SET PARAMETERS</font>

### Read Prioritization Method: Peak vs. Count

The `SORT_CALLS` parameter determines how barcodes are prioritized when assigning reads to cells. This choice depends on your sequencing protocol:

**Count Prioritization (`SORT_CALLS = "count"`):**
- Prioritizes barcodes based on the **number of spots** detected per cell
- **Recommended for mRNA barcode protocols** (e.g., IVT-based perturbations)
- Why: mRNA barcodes produce multiple spots per cell, so more spots = more confident call
- Best for protocols where barcode signal is distributed throughout the cell

**Peak Prioritization (`SORT_CALLS = "peak"`):**
- Prioritizes barcodes based on the **peak intensity** of spots
- **Recommended for DNA barcode protocols** (e.g., Zombie, T7 amplification)
- Why: DNA barcodes typically produce a singular, bright spot per cell
- Best for protocols with focused, high-intensity signal

**Default**: `"count"` is the default and works well for most applications.

### Read Mapping Parameters

**Common to both barcode modes:**
- `Q_MIN`: Minimum quality score for base reads (default: 0)
- `ERROR_CORRECT`: Enable read error correction (default: False)
- `SORT_CALLS`: Method for prioritizing barcodes - 'count' for mRNA protocols, 'peak' for DNA protocols
- `MAX_DISTANCE`: Maximum edit distance for barcode matching (optional)

**Simple mode specific:**
- `BARCODE_COL`: Column in barcode library with full sequences (default: "sgRNA")
- `PREFIX_COL`: Column with pre-computed prefixes, or None for auto-truncation

**Multi mode specific:**
- `MAP_START`, `MAP_END`: Cycle positions defining MAP region for first barcode
- `RECOMB_START`, `RECOMB_END`: Cycle positions defining RECOMB region
- `MAP_COL`, `RECOMB_COL`: Columns in barcode library for MAP/RECOMB sequences
- `RECOMB_FILTER_COL`, `RECOMB_Q_THRESH`: Optional recombination filtering parameters

The Q_min plot below helps determine optimal sequence quality cutoff:

In [None]:
# Plot read-level mapping metrics for df_reads against the expected barcodes,
# sweeping the "Q_min" quality cutoff; used to pick Q_MIN in the next cell.
print("Mapping rate vs. Q_min for determining optimal sequence quality cutoff:")
plot_mapping_vs_threshold(df_reads, barcodes, "Q_min")
plt.show()

#### Left Plot (All Reads):
- Shows how Q_min threshold affects all detected reads
- Blue line: Mapping rate (fraction of reads matching expected barcodes)
- Solid red line: Total number of mapped spots (reads with valid barcodes)
- Dotted red line: Number of unique cells with at least one mapped read

#### Right Plot (Cell-Associated Reads Only):
- Shows the same metrics but only for reads associated with cells

#### Interpreting Q_min Results:
With our optimized peak threshold, these plots confirm that adjusting Q_min provides little benefit:
- The mapping rate (blue line) is already very high at Q_min = 0
- Increasing Q_min only marginally improves mapping rate 
- However, this comes at a significant cost:
  - Total mapped spots and mapped cells decrease substantially
- The small gain in mapping rate doesn't justify the large loss of data

In [None]:
# Settings shared by both barcode modes
Q_MIN = 0
ERROR_CORRECT = False
SORT_CALLS = "count"
MAX_DISTANCE = None  # Optional: maximum edit distance for barcode matching

if BARCODE_TYPE == "simple":
    # Single-barcode protocol: barcode library columns used for matching
    BARCODE_COL = "sgRNA"
    PREFIX_COL = "prefix"

    # Multi-mode parameters are not used in simple mode
    MAP_START = MAP_END = MAP_COL = None
    RECOMB_START = RECOMB_END = RECOMB_COL = None
    RECOMB_FILTER_COL = RECOMB_Q_THRESH = None
    BARCODE_INFO_COLS = None

elif BARCODE_TYPE == "multi":
    # Multi-barcode protocol with recombination detection.
    # Cycle positions delimiting the MAP and RECOMB regions:
    MAP_START = None     # Start cycle for MAP region (e.g., 0)
    MAP_END = None       # End cycle for MAP region (e.g., 5)
    RECOMB_START = None  # Start cycle for RECOMB region (e.g., 6)
    RECOMB_END = None    # End cycle for RECOMB region (e.g., 11)

    # Barcode library columns holding each region's sequences
    MAP_COL = "map_prefix"          # Column with MAP region sequences
    RECOMB_COL = "recomb_prefix"    # Column with RECOMB region sequences

    # Optional recombination filtering
    RECOMB_FILTER_COL = None  # Column to filter by (e.g., "no_recomb_0")
    RECOMB_Q_THRESH = None    # Quality threshold for recombination calls

    # Optional extra barcode info columns to include in output
    BARCODE_INFO_COLS = None  # e.g., ["barcode_full", "gene_family"]

    # Simple-mode parameters are not used in multi mode
    BARCODE_COL = PREFIX_COL = None

else:
    raise ValueError(f"BARCODE_TYPE must be 'simple' or 'multi', got '{BARCODE_TYPE}'")

In [None]:
# Assign reads to cells for the selected barcode mode, then summarize mapping quality.
print("Calling cells with barcode mapping...")

# call_cells arguments that are identical in both barcode modes
shared_call_kwargs = dict(
    df_barcode_library=df_barcode_library,
    q_min=Q_MIN,
    error_correct=ERROR_CORRECT,
    sort_calls=SORT_CALLS,
    max_distance=MAX_DISTANCE,
)

if BARCODE_TYPE == "simple":
    # Single-barcode protocol: reads go straight into call_cells
    df_cells = call_cells(
        df_reads,
        barcode_col=BARCODE_COL,
        prefix_col=PREFIX_COL,
        **shared_call_kwargs,
    )

elif BARCODE_TYPE == "multi":
    # Multi-barcode protocol: reads are prepared first, then cells are called
    from lib.sbs.call_cells import prep_multi_reads

    # MAP / RECOMB region description, passed to both prep_multi_reads and call_cells
    region_kwargs = dict(
        map_start=MAP_START,
        map_end=MAP_END,
        recomb_start=RECOMB_START,
        recomb_end=RECOMB_END,
        map_col=MAP_COL,
        recomb_col=RECOMB_COL,
    )

    print("Preparing multi-barcode reads...")
    df_reads_prepped = prep_multi_reads(df_reads, **region_kwargs)

    print("Calling cells with multi-barcode detection...")
    df_cells = call_cells(
        reads_data=df_reads_prepped,
        recomb_filter_col=RECOMB_FILTER_COL,
        recomb_q_thresh=RECOMB_Q_THRESH,
        barcode_info_cols=BARCODE_INFO_COLS,
        **region_kwargs,
        **shared_call_kwargs,
    )

print(f"Called {len(df_cells)} cells using {BARCODE_TYPE} mode")
display(df_cells)

# The nuclei mask is supplied as both the phenotype and the nuclei input
print("Minimal phenotype features:")
df_sbs_info = extract_phenotype_minimal(
    phenotype_data=nuclei,
    nuclei_data=nuclei,
    wildcards=WILDCARDS,
)
display(df_sbs_info)

# Options common to both heatmap summaries; only `mapping_to` differs between them
heatmap_summary_kwargs = dict(
    mapping_strategy="gene symbols",
    shape="6W_sbs",
    return_plot=False,
    return_summary=True,
)

print("Summary of the fraction of cells mapping to one barcode:")
one_barcode_mapping = plot_cell_mapping_heatmap(
    df_cells, df_sbs_info, barcodes, mapping_to="one", **heatmap_summary_kwargs
)
display(one_barcode_mapping)

print("Summary of the fraction of cells mapping to any barcode:")
any_barcode_mapping = plot_cell_mapping_heatmap(
    df_cells, df_sbs_info, barcodes, mapping_to="any", **heatmap_summary_kwargs
)
display(any_barcode_mapping)

print("Histogram of the number of reads per cell:")
outliers = plot_cell_metric_histogram(df_cells, sort_by=SORT_CALLS)
plt.show()

print("Histogram of the number of counts of each unique gene symbols:")
outliers = plot_gene_symbol_histogram(df_cells)
plt.show()

## Automated Parameter Optimization (Optional)

This section provides an automated way to optimize SBS parameters by testing multiple parameter combinations and evaluating their performance. This is particularly useful when you're unsure about the optimal values for `PEAK_WIDTH` and `THRESHOLD_READS`.

**How it works:**
1. Define a grid of parameters to test (e.g., different peak widths and threshold values)
2. For each combination, the full SBS pipeline is run automatically
3. Results are evaluated using a metric function
4. Multi-panel visualizations show the primary optimization metric alongside secondary metrics to help you understand trade-offs
5. After identifying optimal parameters, subset the results and re-evaluate with detailed plots

**Available metrics:**
- `metric_specificity`: (Recommended) Ratio of cells with exactly one barcode to cells with any barcode. Higher is better (1.0 = perfect, all mapped cells have exactly one barcode).
- `metric_one_barcode_fraction`: Fraction of all cells that map to exactly one barcode. Useful when maximizing total mapped cells is a priority.
- `metric_any_barcode_fraction`: Fraction of all cells that map to any barcode (one or more). Useful for understanding overall mapping coverage.

**Note:** This can take several minutes depending on the number of combinations tested. Start with a small grid for initial exploration.

In [None]:
# # Uncomment to run automated parameter search

# from lib.sbs.automated_parameter_search import (
#     automated_parameter_search,
#     visualize_parameter_results,
#     get_best_parameters,
#     metric_one_barcode_fraction,
#     metric_any_barcode_fraction,
#     metric_specificity,
# )

# # Define parameter grid to test
# # This is the primary input you'll customize for your analysis
# param_grid = {
#     'peak_width': [5, 10, 15, 20],           # Test different peak detection widths
#     'threshold_reads': [50, 100, 150, 200],  # Test different intensity thresholds
# }

In [None]:
# # Select optimization metric
# # metric_specificity is recommended as it balances quality (one barcode per cell) with coverage
# OPTIMIZATION_METRIC = metric_specificity

# # Define parameters that stay constant across all tests
# fixed_params = {
#     'max_filter_width': MAX_FILTER_WIDTH,
#     'call_reads_method': CALL_READS_METHOD,
#     'q_min': Q_MIN,
#     'error_correct': ERROR_CORRECT,
#     'sort_calls': SORT_CALLS,
# }

# # Add mode-specific parameters to fixed_params
# if BARCODE_TYPE == "simple":
#     fixed_params.update({
#         'barcode_col': BARCODE_COL,
#         'prefix_col': PREFIX_COL,
#     })
#     if MAX_DISTANCE is not None:
#         fixed_params['max_distance'] = MAX_DISTANCE
        
# elif BARCODE_TYPE == "multi":
#     fixed_params.update({
#         'map_start': MAP_START,
#         'map_end': MAP_END,
#         'map_col': MAP_COL,
#         'recomb_start': RECOMB_START,
#         'recomb_end': RECOMB_END,
#         'recomb_col': RECOMB_COL,
#         'recomb_filter_col': RECOMB_FILTER_COL,
#         'recomb_q_thresh': RECOMB_Q_THRESH,
#     })
#     if MAX_DISTANCE is not None:
#         fixed_params['max_distance'] = MAX_DISTANCE
#     if BARCODE_INFO_COLS is not None:
#         fixed_params['barcode_info_cols'] = BARCODE_INFO_COLS

In [None]:
# # Run automated parameter search and evaluation
# print(f"Testing {len(param_grid['peak_width']) * len(param_grid['threshold_reads'])} parameter combinations...")
# results_df, df_cells_all = automated_parameter_search(
#     aligned_images=aligned,
#     mask=cells if SEGMENT_CELLS else nuclei,  # Use 'cells' if spots are in cell bodies, 'nuclei' if spots are nuclear
#     barcodes=barcodes,
#     df_barcode_library=df_barcode_library,
#     wildcards=WILDCARDS,
#     bases=BASES,
#     extra_channel_indices=EXTRA_CHANNEL_INDICES,
#     param_grid=param_grid,
#     fixed_params=fixed_params,
#     metric_fn=OPTIMIZATION_METRIC,
#     barcode_type=BARCODE_TYPE,  # Pass the barcode type
#     verbose=True
# )

# # Visualize results with multiple metrics
# # This will show the primary metric (specificity) plus secondary metrics
# # (one-barcode mapping, any-barcode mapping) to understand trade-offs
# visualize_parameter_results(
#     results_df, 
#     df_cells=df_cells_all,  # Pass df_cells_all to compute secondary metrics
#     metric_name='metric_score',
#     show_secondary_metrics=True  # Show multiple metric panels
# )

# # Get best parameters based on the optimization metric
# best_params = get_best_parameters(results_df, metric_name='metric_score', maximize=True)

# # Subset df_cells_all to the best parameter combination for detailed evaluation
# print("\n" + "="*60)
# print("Evaluating best parameter combination in detail")
# print("="*60)
# df_cells_best = df_cells_all[
#     (df_cells_all['peak_width'] == best_params['peak_width']) &
#     (df_cells_all['threshold_reads'] == best_params['threshold_reads'])
# ].copy()

# # Remove parameter tracking columns for cleaner display
# df_cells_best = df_cells_best.drop(columns=['peak_width', 'threshold_reads'], errors='ignore')

# print(f"\nBest parameters produce {len(df_cells_best)} called cells")
# display(df_cells_best)

# # Re-run all evaluation plots with the best parameter set
# print("\nSummary of the fraction of cells mapping to one barcode:")
# one_barcode_mapping_best = plot_cell_mapping_heatmap(
#     df_cells_best,
#     df_sbs_info,
#     barcodes,
#     mapping_to="one",
#     mapping_strategy="gene symbols",
#     shape="6W_sbs",
#     return_plot=False,
#     return_summary=True,
# )
# display(one_barcode_mapping_best)

# print("\nSummary of the fraction of cells mapping to any barcode:")
# any_barcode_mapping_best = plot_cell_mapping_heatmap(
#     df_cells_best,
#     df_sbs_info,
#     barcodes,
#     mapping_to="any",
#     mapping_strategy="gene symbols",
#     shape="6W_sbs",
#     return_plot=False,
#     return_summary=True,
# )
# display(any_barcode_mapping_best)

# print("\nHistogram of the number of reads per cell:")
# outliers_best = plot_cell_metric_histogram(df_cells_best, sort_by=SORT_CALLS)
# plt.show()

# print("\nHistogram of the number of counts of each unique gene symbols:")
# outliers_best = plot_gene_symbol_histogram(df_cells_best)
# plt.show()

# # Optionally: Update your parameters with the best values for use in the config
# # PEAK_WIDTH = int(best_params['peak_width'])
# # THRESHOLD_READS = int(best_params['threshold_reads'])

## Add SBS process parameters to config file

In [None]:
# Assemble the "sbs" section of the config and write the whole config back to disk.
# Keys are inserted in a fixed order because the YAML dump below keeps insertion
# order (sort_keys=False).
sbs_section = config["sbs"] = {
    "alignment_method": ALIGNMENT_METHOD,
    "channel_names": CHANNEL_NAMES,
    "upsample_factor": UPSAMPLE_FACTOR,
    "skip_cycles_indices": SKIP_CYCLES_INDICES,
    "manual_background_cycle_index": MANUAL_BACKGROUND_CYCLE_INDEX,
    "manual_channel_mapping": MANUAL_CHANNEL_MAPPING,
    "extra_channel_indices": EXTRA_CHANNEL_INDICES,
    "max_filter_width": MAX_FILTER_WIDTH,
    "spot_detection_method": SPOT_DETECTION_METHOD,
    "dapi_cycle": DAPI_CYCLE,
    "cyto_cycle": CYTO_CYCLE,
    "cyto_cycle_index": CYTO_CYCLE_INDEX,
    "dapi_index": DAPI_INDEX,
    "cyto_index": CYTO_INDEX,
    "segmentation_method": SEGMENTATION_METHOD,
    "gpu": GPU,
    "reconcile": RECONCILE,
    "segment_cells": SEGMENT_CELLS,
    "df_barcode_library_fp": DF_BARCODE_LIBRARY_FP,
    "threshold_peaks": THRESHOLD_READS,
    "call_reads_method": CALL_READS_METHOD,
    "bases": BASES,
    "q_min": Q_MIN,
    "error_correct": ERROR_CORRECT,
    "sort_calls": SORT_CALLS,
    "barcode_type": BARCODE_TYPE,
}

# Optional edit-distance limit, only written when set
if MAX_DISTANCE is not None:
    sbs_section["max_distance"] = MAX_DISTANCE

# Barcode-mode specific parameters
if BARCODE_TYPE == "simple":
    sbs_section.update(
        barcode_col=BARCODE_COL,
        prefix_col=PREFIX_COL,
    )
elif BARCODE_TYPE == "multi":
    sbs_section.update(
        map_start=MAP_START,
        map_end=MAP_END,
        map_col=MAP_COL,
        recomb_start=RECOMB_START,
        recomb_end=RECOMB_END,
        recomb_col=RECOMB_COL,
    )
    # Optional multi-mode parameters are only written when set
    if RECOMB_FILTER_COL is not None:
        sbs_section["recomb_filter_col"] = RECOMB_FILTER_COL
    if RECOMB_Q_THRESH is not None:
        sbs_section["recomb_q_thresh"] = RECOMB_Q_THRESH
    if BARCODE_INFO_COLS is not None:
        sbs_section["barcode_info_cols"] = BARCODE_INFO_COLS

# Spot-detection specific parameters
if SPOT_DETECTION_METHOD == "standard":
    sbs_section["peak_width"] = PEAK_WIDTH
elif SPOT_DETECTION_METHOD == "spotiflow":
    sbs_section.update(
        spotiflow_cycle_index=SPOTIFLOW_CYCLE_INDEX,
        spotiflow_model=SPOTIFLOW_MODEL,
        spotiflow_threshold=SPOTIFLOW_THRESHOLD,
        spotiflow_min_distance=SPOTIFLOW_MIN_DISTANCE,
    )

# Segmentation-method specific parameters; each branch only reads the
# variables that belong to its own method
if SEGMENTATION_METHOD == "cellpose":
    sbs_section.update(
        nuclei_diameter=NUCLEI_DIAMETER,
        cell_diameter=CELL_DIAMETER,
        nuclei_flow_threshold=NUCLEI_FLOW_THRESHOLD,
        nuclei_cellprob_threshold=NUCLEI_CELLPROB_THRESHOLD,
        cell_flow_threshold=CELL_FLOW_THRESHOLD,
        cell_cellprob_threshold=CELL_CELLPROB_THRESHOLD,
        cellpose_model=CELLPOSE_MODEL,
    )
    # helper_index is optional
    if HELPER_INDEX is not None:
        sbs_section["helper_index"] = HELPER_INDEX
elif SEGMENTATION_METHOD == "stardist":
    sbs_section.update(
        stardist_model=STARDIST_MODEL,
        nuclei_prob_threshold=NUCLEI_PROB_THRESHOLD,
        nuclei_nms_threshold=NUCLEI_NMS_THRESHOLD,
        cell_prob_threshold=CELL_PROB_THRESHOLD,
        cell_nms_threshold=CELL_NMS_THRESHOLD,
    )
elif SEGMENTATION_METHOD == "watershed":
    sbs_section.update(
        threshold_dapi=THRESHOLD_DAPI,
        nucleus_area_min=NUCLEUS_AREA[0],
        nucleus_area_max=NUCLEUS_AREA[1],
        threshold_cell=THRESHOLD_CELL,
    )

# Tuples are converted to lists before the config is dumped
safe_config = convert_tuples_to_lists(config)

# Overwrite the config file: header comments first, then the YAML body
with open(CONFIG_FILE_PATH, "w") as config_file:
    config_file.write(CONFIG_FILE_HEADER)
    yaml.dump(safe_config, config_file, default_flow_style=False, sort_keys=False)