# Train Classifier

This notebook provides tools for training custom cell or vacuole classifiers. It covers labeling data, training models, and selecting the best classifier.

Cells marked with <font color='red'>SET PARAMETERS</font> contain crucial variables that need to be set according to your specific experimental setup and data organization.
Please review and modify these variables as needed before proceeding with the analysis.

## <font color='red'>SET PARAMETERS</font>

### Fixed parameters for aggregate module

- `CONFIG_FILE_PATH`: Path to a Brieflow config file used during processing. Absolute or relative to where workflows are run from.

In [None]:
CONFIG_FILE_PATH = "config/config.yml"

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pyarrow.parquet as pq
import yaml

from lib.aggregate.cell_classification import CellClassifier
from lib.aggregate.cell_data_utils import split_cell_data
from lib.classify.shared import get_latest_run_dir
from lib.classify.apply import (
    build_master_phenotype_df,
    build_montages_and_summary,
    display_pngs_in_plots_and_list_models,
    launch_rankline_ui,
    resolve_classifier_model_dill_path,
    show_model_evaluation_pngs,
)
from lib.classify.calibration import calibrate_confidence
from lib.classify.labeling import (
    _render_next_batch,
    build_class_mapping,
    consolidate_manual_classifications,
    filter_existing_from_pools,
    initialize_labeling_state,
    load_existing_training_data,
    prepare_mask_dataframes,
    resolve_channel_colors,
)
from lib.classify.train import (
    filter_classes,
    load_cellprofiler_data,
    train_classifier_pipeline,
)
from lib.classify.path_utils import get_parquet_config, find_sample_parquet
from lib.shared.configuration_utils import CONFIG_FILE_HEADER
from lib.shared.file_utils import get_filename

_KEYS = ["plate", "well", "tile", "mask_label"]

In [None]:
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

ROOT_FP = Path(config["all"]["root_fp"])
CLASSIFIER_OUTPUT_DIR = ROOT_FP / "classifier"
CLASSIFIER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

## 1. Labeling

This section uses an interactive UI to label objects (e.g., cells) and build training data for machine learning models.

**Steps:** **1a)** Configure classification settings → **1b)** Set gating parameters to choose what objects will be displayed in the visualizer → **1c)** Configure display options → Label objects interactively → Save training dataset

### 1a. <font color='red'>SET PARAMETERS</font>: Classification Settings

This section defines the training dataset and classes that will be used for labeling.

**Classification parameters:**
- `TRAINING_DATA_SOURCE`: Source of training data features. Ex. `"phenotype"` or `"merge"`.
- `MODE`: Object to classify. Default: `"cell"`.
- `CLASS_TITLE`: Name of the new column added to the phenotype dataframe.
- `CLASSIFICATION`: List of categories for classification. Categories appear as 1, 2, 3... in output, corresponding to list order.
- `PLATES_TO_CLASSIFY`: List of plates to include in classification.
- `WELLS_TO_CLASSIFY`: List of wells among the specified plates to classify.
  
**If adding to an existing training dataset:**
- `ADD_TRAINING_DATA`: Set to `False` for first-time training, `True` to add to an existing dataset.
- `EXISTING_TRAINING_DATA`: Only set if `ADD_TRAINING_DATA` is `True`. Specify the filepath of the existing training dataset.
- `RELABEL_CLASSIFICATIONS`: Set to `True` to revisit and modify labels from the existing dataset. Previously labeled data will be shown first.
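
The list-order numbering described for `CLASSIFICATION` can be sketched as follows. This is a minimal illustration that assumes 1-based integer codes; `sketch_class_mapping` is a hypothetical helper, and the actual `build_class_mapping` may return a different structure.

```python
def sketch_class_mapping(classification):
    """Map each class name to a 1-based integer code in list order (hypothetical helper)."""
    if len(set(classification)) != len(classification):
        raise ValueError("Class names must be unique")
    return {name: code for code, name in enumerate(classification, start=1)}


example_mapping = sketch_class_mapping(["Mitotic", "Interphase"])
print(example_mapping)
```

Because the codes follow list order, reordering `CLASSIFICATION` between labeling sessions would change what each number means, so it is probably safest to keep the order fixed once labeling has started.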

In [None]:
ADD_TRAINING_DATA = False
EXISTING_TRAINING_DATA = None
RELABEL_CLASSIFICATIONS = True

TRAINING_DATA_SOURCE = "merge"
MODE = "cell"
CLASS_TITLE = "cell_stage"
CLASSIFICATION = ["Mitotic", "Interphase"]
PLATES_TO_CLASSIFY = [1]
WELLS_TO_CLASSIFY = ["A1"]

class_mapping = build_class_mapping(CLASSIFICATION)

# Set data_source based on training_data_source
if TRAINING_DATA_SOURCE == "merge":
    data_source = ROOT_FP / "merge"
else:
    data_source = ROOT_FP / "phenotype"

In [None]:
# Preview available columns in data source
plate_set = [str(p) for p in PLATES_TO_CLASSIFY]
well_set = list(WELLS_TO_CLASSIFY)

# Determine parquet directory based on data source
parquet_dir, name_suffix = get_parquet_config(
    mode=MODE,
    source=TRAINING_DATA_SOURCE,
    root_fp=ROOT_FP
)

# Find first available parquet file to get column schema
sample_pq = find_sample_parquet(
    plates=plate_set,
    wells=well_set,
    parquet_dir=parquet_dir,
    name_suffix=name_suffix
)

# Preview columns
if sample_pq:
    all_cols = pq.ParquetFile(sample_pq).schema.names
    print(f"Data source: {TRAINING_DATA_SOURCE}")
    print(f"Found {len(all_cols)} columns")
    for col in all_cols: 
        print(col)
else:
    print("Warning: No parquet files found for specified plates/wells")
    all_cols = []

### 1b. <font color='red'>SET PARAMETERS</font>: Gating Settings

Use the parameters below to choose what subset of objects to display in the labeling interface.

**Feature gating (optional):**
Use gating to prioritize cells with specific characteristics for labeling. For example, to find mitotic cells, one might set `GATE_FEATURE = "nucleus_DAPI_mad"` with a high `GATE_MIN_PERCENTILE` to prioritize cells with bright nuclei.

- `GATE_FEATURE`: Feature to gate by (e.g., `"nucleus_DAPI_mad"`). Set to `None` to skip gating.
- `GATE_MIN`: Minimum value (exclusive). Cells below this will be deprioritized.
- `GATE_MAX`: Maximum value (exclusive). Cells above this will be deprioritized.
- `GATE_MIN_PERCENTILE`: Percentile minimum (0-1). Cells below this percentile will be deprioritized.
- `GATE_MAX_PERCENTILE`: Percentile maximum (0-1). Cells above this percentile will be deprioritized.

**Batch settings:**
- `BATCH_SIZE`: Number of images to display per round of classification. Default is 10.
- `OUT_OF_GATE_COUNT`: Number of out-of-gate images to include per batch for diversity (0 to `BATCH_SIZE`). Default is 1. Use this to keep the labeled set balanced across the classes you want to distinguish.
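
To make the interaction between percentile gating and batch composition concrete, here is a rough sketch in plain Python. The helper names (`quantile`, `split_by_gate`, `draw_batch`) are hypothetical, the strict boundary comparison is an assumption, and the real `prepare_mask_dataframes` may behave differently.

```python
import random


def quantile(sorted_values, q):
    """Linearly interpolated quantile of an already sorted list, with q in [0, 1]."""
    position = q * (len(sorted_values) - 1)
    lower = int(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    fraction = position - lower
    return sorted_values[lower] * (1 - fraction) + sorted_values[upper] * fraction


def split_by_gate(cells, feature, min_percentile=None, max_percentile=None):
    """Split cell records into in-gate and out-of-gate pools by feature percentile."""
    values = sorted(cell[feature] for cell in cells)
    low = quantile(values, min_percentile) if min_percentile is not None else float("-inf")
    high = quantile(values, max_percentile) if max_percentile is not None else float("inf")
    # Assumption: bounds are strict (a value equal to a bound falls out of gate)
    in_gate = [c for c in cells if low < c[feature] < high]
    out_of_gate = [c for c in cells if not (low < c[feature] < high)]
    return in_gate, out_of_gate


def draw_batch(in_gate, out_of_gate, batch_size, out_of_gate_count, seed=42):
    """Fill a batch mostly from the in-gate pool plus a few out-of-gate records."""
    rng = random.Random(seed)
    n_out = min(out_of_gate_count, len(out_of_gate), batch_size)
    n_in = min(batch_size - n_out, len(in_gate))
    return rng.sample(in_gate, n_in) + rng.sample(out_of_gate, n_out)


# Toy data: 100 cells whose feature value equals their index
cells = [{"mask_label": i, "nucleus_DAPI_mad": float(i)} for i in range(100)]
in_gate, out_of_gate = split_by_gate(cells, "nucleus_DAPI_mad", min_percentile=0.9)
batch = draw_batch(in_gate, out_of_gate, batch_size=10, out_of_gate_count=1)
print(len(in_gate), len(out_of_gate), len(batch))
```

With a high minimum percentile the in-gate pool is small, so a nonzero `OUT_OF_GATE_COUNT` is what keeps examples of the more common class appearing in each batch.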

In [None]:
GATE_FEATURE = None
GATE_MIN = None
GATE_MAX = None
GATE_MIN_PERCENTILE = None
GATE_MAX_PERCENTILE = None
BATCH_SIZE = 10
OUT_OF_GATE_COUNT = 1

In [None]:
# Prepare dataframes for gating
summary_df, in_gate_df, out_of_gate_df, gate_dbg = prepare_mask_dataframes(
    mode=MODE,
    data_source=data_source,
    plates=PLATES_TO_CLASSIFY,
    wells=WELLS_TO_CLASSIFY,
    keys=_KEYS,
    gate_feature=GATE_FEATURE,
    gate_min=GATE_MIN,
    gate_max=GATE_MAX,
    gate_min_percentile=GATE_MIN_PERCENTILE,
    gate_max_percentile=GATE_MAX_PERCENTILE,
    verbose=True,
)

In [None]:
# Load existing training data if adding
if ADD_TRAINING_DATA and EXISTING_TRAINING_DATA:
    seeded_df, _EXISTING_KEYS = load_existing_training_data(
        EXISTING_TRAINING_DATA, MODE, CLASS_TITLE
    )
    if not RELABEL_CLASSIFICATIONS:
        in_gate_df, out_of_gate_df = filter_existing_from_pools(
            in_gate_df, out_of_gate_df, _EXISTING_KEYS
        )
else:
    seeded_df = None
    _EXISTING_KEYS = set()

print(f"[training] Existing keys loaded: {len(_EXISTING_KEYS)}")

### 1c. <font color='red'>SET PARAMETERS</font>: Display Settings

This section sets how and in what order objects will be displayed by the visualizer.

**Channel visualization:**
- `DISPLAY_CHANNEL`: List of channels to display for manual classification.
- `CHANNEL_COLORS`: Colors for each channel. Must align with `DISPLAY_CHANNEL` order. See [matplotlib colors](https://matplotlib.org/stable/gallery/color/named_colors.html).

**Selection method:**
- `TRAINING_DATASET_SELECTION`: Choose `"random"` or `"top_n"`.
  - `"random"`: Randomly select masks from specified plates and wells.
  - `"top_n"`: Ranks tiles by the number of objects in each and selects masks from those tiles.
- `TOP_N`: If using `"top_n"`, specify which ranked tile to use. Ex. `1` would display cells (assuming `MODE="cell"`) from the tile that has the most cells.

**Other settings:**
- `SCALE_BAR`: Scale bar length in pixels. If value exceeds image size, displays as dashed lines.
- `RANDOM_SEED`: Random seed for reproducibility.
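
The `"top_n"` ranking can be pictured with the sketch below. It assumes a 1-based rank (taken from the `1` example above) and breaks ties by tile key; `cells_from_ranked_tile` is a hypothetical helper and not part of the library.

```python
from collections import Counter


def cells_from_ranked_tile(cells, rank):
    """Return cells from the tile with the rank-th highest object count (1 = most)."""
    counts = Counter((c["plate"], c["well"], c["tile"]) for c in cells)
    # Sort by count descending, then by tile key so ties are deterministic
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    if not 1 <= rank <= len(ranked):
        raise ValueError(f"rank must be between 1 and {len(ranked)}")
    chosen_tile = ranked[rank - 1][0]
    return [c for c in cells if (c["plate"], c["well"], c["tile"]) == chosen_tile]


# Toy data: tile 7 holds five cells, tile 3 holds two
cells = (
    [{"plate": 1, "well": "A1", "tile": 7, "mask_label": i} for i in range(5)]
    + [{"plate": 1, "well": "A1", "tile": 3, "mask_label": i} for i in range(2)]
)
densest = cells_from_ranked_tile(cells, rank=1)
print(len(densest), densest[0]["tile"])
```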

In [None]:
DISPLAY_CHANNEL = ["DAPI"]
CHANNEL_COLORS = ["b"]
TRAINING_DATASET_SELECTION = "random"
TOP_N = 0
SCALE_BAR = 30
RANDOM_SEED = 42

In [None]:
# Initialize labeling and resolve channels
CHANNEL_NAMES = config["phenotype"]["channel_names"]
CHANNEL_INDICES = [CHANNEL_NAMES.index(ch) for ch in DISPLAY_CHANNEL]
resolved_colors = resolve_channel_colors(DISPLAY_CHANNEL, CHANNEL_COLORS)

_state = initialize_labeling_state(
    random_seed=RANDOM_SEED,
    mode=MODE,
    class_title=CLASS_TITLE,
    keys=_KEYS,
    existing_classified_df=seeded_df,
)

_render_next_batch(
    state=_state,
    display_channel=DISPLAY_CHANNEL,
    add_training_data=ADD_TRAINING_DATA,
    keys=_KEYS,
    class_title=CLASS_TITLE,
    classification=CLASSIFICATION,
    relabel_classifications=RELABEL_CLASSIFICATIONS,
    training_dataset_selection=TRAINING_DATASET_SELECTION,
    batch_size=BATCH_SIZE,
    summary_df=summary_df,
    in_gate_df=in_gate_df,
    out_of_gate_df=out_of_gate_df,
    out_of_gate_count=OUT_OF_GATE_COUNT,
    channel_indices=CHANNEL_INDICES,
    data_source=data_source,
    channel_names=CHANNEL_NAMES,
    mode=MODE,
    resolved_colors=resolved_colors,
    scale_bar=SCALE_BAR,
    existing_keys=_EXISTING_KEYS,
    gate_feature_present=(GATE_FEATURE is not None),
)

In [None]:
# Consolidate manual classifications into training dataset
consolidated_df, training_dataset_out_path = consolidate_manual_classifications(
    manual_classified_df=_state["manual_classified_df"],
    class_title=CLASS_TITLE,
    mode=MODE,
    data_source=data_source,
    classifier_output_dir=CLASSIFIER_OUTPUT_DIR,
    write=True,
    verbose=True,
)
display(consolidated_df)

## 2. Training

Configure training parameters and train multiple model types to find the best classifier.

### 2a. <font color='red'>SET PARAMETERS</font>: Training Dataset Configuration

- `TRAINING_DATASET_FP`: Path to training dataset(s) in list format. Set to `None` to use the last classified dataset.
- `METADATA_COLS_FP`: Path to save metadata columns for use in aggregate pipeline.
- `METADATA_COLS`: Columns to treat as metadata (not features). Defaults provided, modify as needed based on output above.
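- `TRAINING_CHANNELS`: Channels to include in training features. Set to `None` to use all channels.
- `TRAINING_OBJECT_TYPES`: Object types (e.g., `"nucleus"`, `"cell"`, `"cytoplasm"`) to include in training features. Set to `None` to use all object types.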
- `TRAINING_NAME`: Name identifier for this training run.
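
As a rough picture of how metadata columns, channels, and object types narrow down the feature set, consider the sketch below. It assumes feature names of the form `<object>_<channel>_<measurement>` (as in `nucleus_DAPI_mad`); `split_columns` and `filter_features` are hypothetical helpers, and the actual pipeline may treat features that name no channel differently.

```python
def split_columns(columns, metadata_candidates, class_title):
    """Separate metadata columns (including the target) from feature columns."""
    metadata = [c for c in metadata_candidates if c in columns]
    metadata.append(class_title)
    metadata_set = set(metadata)
    features = [c for c in columns if c not in metadata_set]
    return metadata, features


def filter_features(features, channels=None, object_types=None):
    """Keep features matching the requested channels and object types.

    Assumes names of the form "<object>_<channel>_<measurement>".
    None means "no filtering" for that dimension.
    """
    kept = []
    for name in features:
        parts = name.split("_")
        if object_types is not None and parts[0] not in object_types:
            continue
        if channels is not None and not any(ch in parts for ch in channels):
            continue
        kept.append(name)
    return kept


columns = [
    "plate", "well", "tile", "cell_stage",
    "nucleus_DAPI_mad", "cell_Cy5_mean", "cytoplasm_Cy5_mean",
]
metadata, features = split_columns(columns, ["plate", "well", "tile", "site"], "cell_stage")
dapi_only = filter_features(features, channels=["DAPI"])
no_nucleus = filter_features(features, object_types=["cell", "cytoplasm"])
print(metadata, dapi_only, no_nucleus)
```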

In [None]:
TRAINING_DATASET_FP = None
METADATA_COLS_FP = "config/cell_data_metadata_cols.tsv"
TRAINING_CHANNELS = None  # Set to ["DAPI", "Cy5"] or subset to filter
TRAINING_OBJECT_TYPES = None  # Set to ["nucleus", "cell", "cytoplasm"] or subset to filter
TRAINING_NAME = None

In [None]:
# These columns will be treated as metadata (not features) if they exist
METADATA_COLS = [
    "label",
    "plate",
    "well",
    "tile",
    "cell_0",
    "i_0",
    "j_0",
    "site",
    "cell_1",
    "i_1",
    "j_1",
    "distance",
    "fov_distance_0",
    "fov_distance_1",
    "cell_barcode_0",
    "gene_symbol_0",
    "cell_barcode_1",
    "gene_symbol_1",
    "mapped_single_gene",
    "channels_min",
    "nucleus_i",
    "nucleus_j",
    "nucleus_bounds_0",
    "nucleus_bounds_1",
    "nucleus_bounds_2",
    "nucleus_bounds_3",
    "cell_i",
    "cell_j",
    "cell_bounds_0",
    "cell_bounds_1",
    "cell_bounds_2",
    "cell_bounds_3",
    "cytoplasm_i",
    "cytoplasm_j",
    "cytoplasm_bounds_0",
    "cytoplasm_bounds_1",
    "cytoplasm_bounds_2",
    "cytoplasm_bounds_3",
]

# Build metadata_cols by checking which candidates exist
metadata_cols = [
    col for col in METADATA_COLS if col in consolidated_df.columns
]

# Always add the classification target column
metadata_cols.append(CLASS_TITLE)

# Save metadata cols to file for use in aggregate pipeline
pd.Series(metadata_cols).to_csv(METADATA_COLS_FP, index=False, header=False, sep="\t")
# User feedback
missing_cols = [
    col for col in METADATA_COLS if col not in consolidated_df.columns
]
print(f"Saved {len(metadata_cols)} metadata columns to {METADATA_COLS_FP}")
print(
    f"Found: {len(metadata_cols) - 1}/{len(METADATA_COLS)} candidate columns"
)
if missing_cols:
    print(f"Missing: {', '.join(missing_cols)}")


In [None]:
# Load training data
if TRAINING_DATASET_FP is not None:
    print("Loading training data from specified file")
    data = load_cellprofiler_data(
        TRAINING_DATASET_FP, class_title=CLASS_TITLE, metadata_cols=METADATA_COLS
    )
else:
    print("Using last classified dataset")
    data = consolidated_df

# Filter features and prepare for training
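channel_names = config["phenotype"]["channel_names"]
# TRAINING_CHANNELS = None means "use every channel"; guard against iterating over None
selected_channels = channel_names if TRAINING_CHANNELS is None else TRAINING_CHANNELS
feature_markers = {c: True for c in channel_names if c in selected_channels}
exclude_markers = [c for c in channel_names if c not in selected_channels]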

# Calculate excluded object types for display
if TRAINING_OBJECT_TYPES is not None:
    exclude_object_types = [
        obj for obj in ["nucleus", "cell", "cytoplasm", "second_obj"]
        if obj not in TRAINING_OBJECT_TYPES
    ]
else:
    exclude_object_types = []

print(f"Class names: {CLASSIFICATION}")
print(f"Class names to stored numeric values: {class_mapping}")
print(f"Features to train upon: {feature_markers}")
print(f"Features to exclude: {exclude_markers}")
if TRAINING_OBJECT_TYPES is not None:
    print(f"Object types to train upon: {TRAINING_OBJECT_TYPES}")
    print(f"Object types to exclude: {exclude_object_types}")
else:
    print("Object types to train upon: All (nucleus, cell, cytoplasm, second_obj)")
print(f"Target column: {CLASS_TITLE}")

### 2b. <font color='red'>SET PARAMETERS</font>: Filter Training Classes

- `REMOVE_MASK_LABELS`: List of class labels to exclude from training (e.g., `unknown`). Set to `None` to keep all classes.

In [None]:
REMOVE_MASK_LABELS = None
class_labels, filtered_class_mapping, class_id = filter_classes(CLASSIFICATION, class_mapping, REMOVE_MASK_LABELS)

In [None]:
# Define model configurations to train and evaluate
model_configs = [
    ('rf_standard', 'rf', 'standard', None),
    ('svc_standard', 'svc', 'standard', None),
    ('xgb_standard', 'xgb', 'standard', None),
    ('lgb_standard', 'lgb', 'standard', None),
    ('xgb_robust', 'xgb', 'robust', None),
    ('xgb_minmax', 'xgb', 'minmax', None),
    ('xgb_none', 'xgb', 'none', None),
    ('xgb_none_var', 'xgb', 'none', {'enhance': True, 'remove_low_variance': True, 'remove_correlated': False, 'select_k_best': None}),
    ('xgb_none_corr', 'xgb', 'none', {'enhance': True, 'remove_low_variance': False, 'remove_correlated': True, 'select_k_best': None}),
    ('xgb_none_kbest100', 'xgb', 'none', {'enhance': True, 'remove_low_variance': False, 'remove_correlated': False, 'select_k_best': 100}),
    ('xgb_none_combined', 'xgb', 'none', {'enhance': True, 'remove_low_variance': True, 'remove_correlated': True, 'select_k_best': 100}),
]

In [None]:
pipeline_result = train_classifier_pipeline(
    data=data,
    class_title=CLASS_TITLE,
    class_id=class_id,
    class_labels=class_labels,
    filtered_class_mapping=filtered_class_mapping,
    metadata_cols=metadata_cols,
    feature_markers=feature_markers,
    exclude_markers=exclude_markers,
    training_name=TRAINING_NAME,
    model_configs=model_configs,
    classifier_output_dir=CLASSIFIER_OUTPUT_DIR,
    training_channels=TRAINING_CHANNELS,
    training_object_types=TRAINING_OBJECT_TYPES,
    verbose=True,
)
multiclass_df = pipeline_result['metrics_df']
print(multiclass_df.head())

# Save the run directory for later use in evaluation
last_training_run_dir = pipeline_result["dirs"]["run"]
print(f"All outputs written to: {last_training_run_dir}")
print(f"\nSaved run directory for evaluation: {Path(last_training_run_dir).name}")

## 3. Selection

Run the cells below to evaluate trained models and select the best one with an appropriate confidence threshold.

### 3a. <font color='red'>SET PARAMETERS</font>: Model Selection Settings

- `MODEL_RUN_DIR`: Path to the model run directory to test (e.g., `"run_20250903_114514"`). Set to `None` to use the last training run.
- `TEST_PLATES`: Plates to use for testing.
- `TEST_WELLS`: Wells to use for testing.
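
When `MODEL_RUN_DIR` is `None`, the most recent run is used. One plausible way to find it, assuming run directories are named `run_YYYYMMDD_HHMMSS` as in the example above, is sketched here. `latest_run_name` is a hypothetical stand-in; the actual `get_latest_run_dir` may use a different rule, such as modification time.

```python
import tempfile
from pathlib import Path


def latest_run_name(classifier_dir):
    """Return the newest "run_YYYYMMDD_HHMMSS" directory name, or None if absent."""
    run_names = sorted(
        p.name
        for p in Path(classifier_dir).iterdir()
        if p.is_dir() and p.name.startswith("run_")
    )
    # Zero-padded timestamps sort chronologically as plain strings
    return run_names[-1] if run_names else None


# Demonstrate on a throwaway directory rather than real classifier output
with tempfile.TemporaryDirectory() as tmp:
    for name in ["run_20250903_114514", "run_20251217_103101"]:
        (Path(tmp) / name).mkdir()
    newest = latest_run_name(tmp)
print(newest)
```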

In [None]:
MODEL_RUN_DIR = None  # Set to specific run name (e.g., "run_20251217_103101") to override

TEST_PLATES = ["1"]
TEST_WELLS = ["A1", "A2", "A3"]

# Determine model run directory to use
if MODEL_RUN_DIR is None:
    MODEL_RUN_DIR = get_latest_run_dir(CLASSIFIER_OUTPUT_DIR)
    print(f"Using most recent run directory: {MODEL_RUN_DIR}")
else:
    print(f"Using specified run directory: {MODEL_RUN_DIR}")

In [None]:
plate = TEST_PLATES
well = TEST_WELLS

CLASSIFIER_DIR_PATH = Path(CLASSIFIER_OUTPUT_DIR) / "classifier" / str(MODEL_RUN_DIR)

plates = [str(p) for p in (plate if isinstance(plate, (list, tuple)) else [plate])]
wells = [str(w) for w in (well if isinstance(well, (list, tuple)) else [well])]

# Determine parquet directory based on data source (same logic as the labeling section)
parquet_dir, _ = get_parquet_config(
    mode=MODE,
    source=TRAINING_DATA_SOURCE,
    root_fp=ROOT_FP
)

master_phenotype_df, meta = build_master_phenotype_df(
    plates=plates,
    wells=wells,
    mode=MODE,
    parquet_dir=parquet_dir,
    display_fn=display,
    verbose=True,
)

metadata, features = split_cell_data(master_phenotype_df, METADATA_COLS)
print(metadata.shape, features.shape)

**Note:** The cell below displays evaluation statistics for the selected model run. If no model was selected, the last trained model will be displayed. Based on the accuracy and F1 scores, select a model in the following cell. Your selection will be saved when you run the final cell of this notebook.

In [None]:
_ = display_pngs_in_plots_and_list_models(CLASSIFIER_DIR_PATH, width=1200)

### 3b. <font color='red'>SET PARAMETERS</font>: Model and Montage Settings

- `CLASSIFIER_MODEL`: Name of the model to use (e.g., `"xgb_standard"`). Set to `None` to use the best performing model.
- `MONTAGE_CHANNEL`: Channel to use for montage images. Should be one of the channels used for training.
- `COLLAPSE_COLS`: Columns to collapse on when creating classification summaries (e.g., `["plate", "well"]`).
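
Collapsing on `COLLAPSE_COLS` amounts to counting predicted classes within each group. The sketch below shows the idea in plain Python; `summarize_classes` is a hypothetical helper, and the summary produced by `build_montages_and_summary` may include additional columns.

```python
from collections import Counter, defaultdict


def summarize_classes(records, collapse_cols, class_title):
    """Count predicted classes within each group defined by collapse_cols."""
    summary = defaultdict(Counter)
    for record in records:
        group = tuple(record[col] for col in collapse_cols)
        summary[group][record[class_title]] += 1
    return {group: dict(counts) for group, counts in summary.items()}


# Toy predictions: three cells in well A1, one in well A2
records = [
    {"plate": 1, "well": "A1", "cell_stage": 1},
    {"plate": 1, "well": "A1", "cell_stage": 1},
    {"plate": 1, "well": "A1", "cell_stage": 2},
    {"plate": 1, "well": "A2", "cell_stage": 2},
]
summary = summarize_classes(records, ["plate", "well"], "cell_stage")
print(summary)
```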

In [None]:
CLASSIFIER_MODEL = None
MONTAGE_CHANNEL = "DAPI"
COLLAPSE_COLS = ["plate", "well"]

In [None]:
CLASSIFIER_PATH, model_name = resolve_classifier_model_dill_path(CLASSIFIER_DIR_PATH, CLASSIFIER_MODEL)
print("Selected model: " + model_name)
_ = show_model_evaluation_pngs(CLASSIFIER_DIR_PATH, model_name, width=500)

In [None]:
classifier = CellClassifier.load(CLASSIFIER_PATH)
classified_metadata, classified_features = classifier.classify_cells(metadata, features)

print(classified_metadata.head())
CELL_CLASSES = list(classified_metadata[CLASS_TITLE].unique())

print(CLASS_TITLE + " counts:")
print(classified_metadata[CLASS_TITLE].value_counts())

print("\n" + CLASS_TITLE + " confidences:")
classified_metadata[CLASS_TITLE + "_confidence"].hist()
plt.show()

### 3c. <font color='red'>SET PARAMETERS</font>: Confidence Calibration

When your classifier predicts whether an object belongs to a category, it also gives a confidence score. However, these confidence scores can be inaccurate: the model might say "80% confident" when it is actually correct only 60% of the time.

**When to apply post-hoc correction:**
- First, run with `CONFIDENCE_CORRECTION = None` and examine the confidence histogram and rankline UI below
- If high-confidence predictions seem frequently wrong, or the confidence scores don't match reality, enable calibration
- If your confidence scores already look reliable, you can skip calibration

**Set parameters:**
- `CONFIDENCE_CORRECTION`: Set to `"post-hoc"` to enable correction. Defaults to `None` (skip correction).
- `CALIBRATION_DATASET_FP`: Path to additional calibration dataset (only used if correction enabled). Set to `None` to use the training dataset.
- `CALIBRATION_METHOD`: Calibration method to use. Options are `"isotonic"` (recommended, works for most cases) or `"sigmoid"` (best for very small datasets < 100 objects).

**Recommended workflow:**
1. Start with `CONFIDENCE_CORRECTION = None`
2. Run the cells below and examine the confidence distribution
3. Use the rankline UI to manually check if high-confidence predictions match your expectations
4. If confidence scores seem unreliable, come back and set `CONFIDENCE_CORRECTION = "post-hoc"` with `CALIBRATION_METHOD = "isotonic"`
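
If you have manual labels for a subset of predictions, the mismatch between stated confidence and observed accuracy can be checked numerically. The sketch below bins predictions by confidence and compares the two. It is an illustration only: `reliability_table` is a hypothetical helper, not part of `calibrate_confidence`.

```python
def reliability_table(confidences, correct, n_bins=5):
    """Compare mean confidence with observed accuracy inside equal-width bins."""
    bins = [[] for _ in range(n_bins)]
    for conf, is_correct in zip(confidences, correct):
        index = min(int(conf * n_bins), n_bins - 1)  # conf == 1.0 goes in the last bin
        bins[index].append((conf, is_correct))
    rows = []
    for index, members in enumerate(bins):
        if not members:
            continue
        mean_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        rows.append({
            "bin": (index / n_bins, (index + 1) / n_bins),
            "n": len(members),
            "mean_confidence": mean_conf,
            "accuracy": accuracy,
        })
    return rows


# Toy example: five predictions at 0.8 confidence, only three of them correct
rows = reliability_table([0.8] * 5, [True, True, True, False, False])
for row in rows:
    print(row)
```

A well-calibrated classifier shows `mean_confidence` close to `accuracy` in every bin. Large gaps in the high-confidence bins are the situation where post-hoc correction is most likely to help.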

In [None]:
CONFIDENCE_CORRECTION = None
CALIBRATION_DATASET_FP = None
CALIBRATION_METHOD = "isotonic"

In [None]:
if CONFIDENCE_CORRECTION is None:
    print("Skipping confidence calibration")
else:
    if CALIBRATION_DATASET_FP is not None:
        manual_labeled_data = load_cellprofiler_data(CALIBRATION_DATASET_FP)
    else:
        print("No calibration dataset provided, using training dataset")
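        # Use .get so a missing "classifier" section reaches the ValueError below instead of a KeyError
        CALIBRATION_DATASET_FP = (config.get("classifier") or {}).get("training_datasets")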
        if CALIBRATION_DATASET_FP is None or len(CALIBRATION_DATASET_FP) == 0:
            raise ValueError("No calibration dataset provided and no training dataset found in config.")
        manual_labeled_data = load_cellprofiler_data(CALIBRATION_DATASET_FP)

    classified_metadata, meta = calibrate_confidence(
        master_phenotype_df=master_phenotype_df,
        classified_metadata=classified_metadata,
        manual_labeled_data=manual_labeled_data,
        classify_by=MODE,
        class_title=CLASS_TITLE,
        classifier_path=CLASSIFIER_PATH,
        confidence_correction=CONFIDENCE_CORRECTION,
        calibration_method=CALIBRATION_METHOD,
        test_plate=plates,
        test_well=wells,
        min_samples_isotonic=50,
        verbose=True,
    )

In [None]:
fig, axes, montages, titles, ORDERED_CLASSES, summary_df = build_montages_and_summary(
    master_phenotype_df=master_phenotype_df,
    classified_metadata=classified_metadata,
    classify_by=MODE,
    class_mapping=class_mapping,
    class_title=CLASS_TITLE,
    root_fp=ROOT_FP,
    channels=config["phenotype"]["channel_names"],
    montage_channel=MONTAGE_CHANNEL,
    collapse_cols=COLLAPSE_COLS,
    verbose=True,
    show_figure=True,
    display_fn=display,
)

### 3d. <font color='red'>SET PARAMETERS</font>: Rankline UI Settings

- `MINIMUM_DIFFERENCE`: Minimum confidence difference for comparison.

In [None]:
MINIMUM_DIFFERENCE = 0.001

In [None]:
w = launch_rankline_ui(
    classified_metadata=classified_metadata,
    class_title=CLASS_TITLE,
    classify_by=MODE,
    class_mapping=class_mapping,
    data_source=data_source,
    images_source=ROOT_FP / "phenotype",
    channel_names=list(config["phenotype"]["channel_names"]),
    display_channels=DISPLAY_CHANNEL,
    channel_colors=CHANNEL_COLORS,
    test_plate=plates,
    test_well=wells,
    scale_bar_px=SCALE_BAR,
    minimum_difference=MINIMUM_DIFFERENCE,
    thumbnail_px=150,
    auto_display=True,
)

### 3e. <font color='red'>SET PARAMETERS</font>: Confidence Threshold

- `CONFIDENCE_THRESHOLD`: Minimum confidence required for predictions. Cells below this threshold will be filtered out when splitting into classes in the aggregate pipeline. Choose based on the confidence distribution and rankline UI above.
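
Before committing to a threshold, it can help to see how many cells of each class would survive it. The sketch below uses the `<CLASS_TITLE>_confidence` column naming seen earlier in this notebook and assumes a cell is kept when its confidence is greater than or equal to the threshold; `apply_confidence_threshold` is a hypothetical helper, and the aggregate pipeline may apply the cutoff differently.

```python
def apply_confidence_threshold(records, class_title, threshold):
    """Keep records whose confidence meets the threshold; report retention per class."""
    confidence_key = f"{class_title}_confidence"
    kept = [r for r in records if r[confidence_key] >= threshold]
    retention = {}
    for label in sorted({r[class_title] for r in records}):
        total = sum(1 for r in records if r[class_title] == label)
        surviving = sum(1 for r in kept if r[class_title] == label)
        retention[label] = surviving / total
    return kept, retention


# Toy predictions for two classes
records = [
    {"cell_stage": 1, "cell_stage_confidence": 0.95},
    {"cell_stage": 1, "cell_stage_confidence": 0.40},
    {"cell_stage": 2, "cell_stage_confidence": 0.70},
    {"cell_stage": 2, "cell_stage_confidence": 0.55},
]
kept, retention = apply_confidence_threshold(records, "cell_stage", threshold=0.5)
print(len(kept), retention)
```

If one class loses far more cells than the others at a given threshold, downstream class comparisons may be skewed, so the per-class retention is worth checking alongside the overall histogram.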

In [None]:
CONFIDENCE_THRESHOLD = 0.5

In [None]:
config["classifier"] = {
    "CLASSIFIER_PATH": str(CLASSIFIER_PATH),
    "CONFIDENCE_THRESHOLD": CONFIDENCE_THRESHOLD,
    "METADATA_COLS_FP": str(METADATA_COLS_FP),
}

with open(CONFIG_FILE_PATH, "w") as config_file:
    config_file.write(CONFIG_FILE_HEADER)
    yaml.safe_dump(config, config_file, default_flow_style=False, sort_keys=False)

print("Saved classifier settings to config.")