# Train Classifier

This notebook provides tools for training custom cell or vacuole classifiers. It covers labeling data, training models, and selecting the best classifier.

Cells marked with <font color='red'>SET PARAMETERS</font> contain crucial variables that need to be set according to your specific experimental setup and data organization.
Please review and modify these variables as needed before proceeding with the analysis.

## <font color='red'>SET PARAMETERS</font>

### Fixed parameters for aggregate module

- `CONFIG_FILE_PATH`: Path to a Brieflow config file used during processing. Absolute or relative to where workflows are run from.

In [None]:
CONFIG_FILE_PATH = "config/config.yml"

In [None]:
import re
from datetime import datetime
from pathlib import Path

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import yaml

from lib.aggregate.cell_classification import CellClassifier
from lib.aggregate.cell_data_utils import split_cell_data
from lib.classify.apply import (
    build_master_phenotype_df,
    build_montages_and_summary,
    display_pngs_in_plots_and_list_models,
    launch_rankline_ui,
    resolve_classifier_model_dill_path,
    show_model_evaluation_pngs,
)
from lib.classify.calibration import calibrate_confidence
from lib.classify.labeling import (
    _ensure_mc_schema,
    _normalize_keys,
    _pq_path_for,
    _render_next_batch,
    build_class_mapping,
    consolidate_manual_classifications,
    filter_existing_from_pools,
    initialize_labeling_state,
    load_existing_training_data,
    prepare_mask_dataframes,
    resolve_channel_colors,
    to_list_of_str,
)
from lib.classify.train import (
    filter_classes,
    load_cellprofiler_data,
    train_classifier_pipeline,
)
from lib.phenotype.constants import DEFAULT_METADATA_COLS
from lib.shared.configuration_utils import CONFIG_FILE_HEADER
from lib.shared.file_utils import get_filename

_KEYS = ["plate", "well", "tile", "mask_label"]

In [None]:
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

ROOT_FP = Path(config["all"]["root_fp"])
PHENOTYPE_OUTPUT_FP = ROOT_FP / "phenotype"
MERGE_OUTPUT_FP = ROOT_FP / "merge"
CLASSIFIER_OUTPUT_DIR = ROOT_FP / "classifier"
CLASSIFIER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

## 1. Labeling

Use an interactive UI to label cells or vacuoles to create training data for machine learning models.

### 1a. <font color='red'>SET PARAMETERS</font>: Classification Settings

**If adding to an existing training dataset:**
- `ADD_TRAINING_DATA`: Set to `False` for first-time training, `True` to add to an existing dataset.
- `EXISTING_TRAINING_DATA`: Only set if `ADD_TRAINING_DATA` is `True`. Specify the filename of the existing training dataset.
- `RELABEL_CLASSIFICATIONS`: Set to `True` to revisit and modify labels from the existing dataset. Previously labeled data will be shown first.
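When `RELABEL_CLASSIFICATIONS` is `False`, masks that are already in the existing dataset are removed from the labeling pools. Below is a minimal sketch of that key-based filter. It assumes each mask is identified by the `(plate, well, tile, mask_label)` tuple that this notebook uses as `_KEYS`. `mask_key` and `drop_already_labeled` are hypothetical helpers, not the `filter_existing_from_pools` implementation.

```python
# Hypothetical sketch: identify masks by their key tuple and skip ones already labeled
KEYS = ("plate", "well", "tile", "mask_label")


def mask_key(row):
    """Build a hashable identity tuple for one mask record (a dict).

    Values are stringified so that plate 1 and plate "1" compare equal.
    """
    return tuple(str(row[k]) for k in KEYS)


def drop_already_labeled(pool, existing_keys):
    """Return the records in `pool` whose key is not in `existing_keys`."""
    return [row for row in pool if mask_key(row) not in existing_keys]


existing = [{"plate": 1, "well": "A1", "tile": 5, "mask_label": 12}]
pool = [
    {"plate": 1, "well": "A1", "tile": 5, "mask_label": 12},
    {"plate": 1, "well": "A1", "tile": 5, "mask_label": 13},
]
existing_keys = {mask_key(r) for r in existing}
remaining = drop_already_labeled(pool, existing_keys)
print(len(remaining))  # 1
```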

**Classification parameters:**
- `TRAINING_DATA_SOURCE`: Source of training data - `"phenotype"` or `"merge"`.
- `MODE`: Classification mode - `"cell"` (default) or `"vacuole"`.
- `CLASS_TITLE`: Name of the new column added to the phenotype dataframe.
- `CLASSIFICATION`: List of categories for classification. Categories appear as 1, 2, 3... in output, corresponding to list order.
- `PLATES_TO_CLASSIFY`: List of plates to include in classification.
- `WELLS_TO_CLASSIFY`: List of wells among the specified plates to classify.
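The numbering described for `CLASSIFICATION` amounts to a 1-based mapping in list order. The sketch below is a hypothetical stand-in for `build_class_mapping`. The real function's return type and validation may differ.

```python
def class_mapping_sketch(classes):
    """Map class names to 1-based integer codes in list order (hypothetical helper)."""
    if len(set(classes)) != len(classes):
        raise ValueError("class names must be unique")
    return {name: code for code, name in enumerate(classes, start=1)}


mapping = class_mapping_sketch(["Mitotic", "Interphase"])
print(mapping)  # {'Mitotic': 1, 'Interphase': 2}

# The inverse lookup turns stored numeric values back into readable names
inverse = {code: name for name, code in mapping.items()}
```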

In [None]:
ADD_TRAINING_DATA = False
EXISTING_TRAINING_DATA = None
RELABEL_CLASSIFICATIONS = True

TRAINING_DATA_SOURCE = "merge"
MODE = "cell"
CLASS_TITLE = "cell_stage"
CLASSIFICATION = ["Mitotic", "Interphase"]
PLATES_TO_CLASSIFY = [1]
WELLS_TO_CLASSIFY = ["A1"]

class_mapping = build_class_mapping(CLASSIFICATION)

In [None]:
# Print all columns in the data source to see what's available

plate_set = [str(p) for p in PLATES_TO_CLASSIFY]
well_set = list(WELLS_TO_CLASSIFY)

# Determine parquet directory based on data source
if TRAINING_DATA_SOURCE == "merge":
    parquet_dir = MERGE_OUTPUT_FP / "parquets"
    name_suffix = "merge_final"
else:  # phenotype
    parquet_dir = PHENOTYPE_OUTPUT_FP / "parquets"
    name_suffix = "phenotype_cp" if MODE == "cell" else "phenotype_vacuoles"

# The schema of the first available parquet is enough to list the columns
all_cols = []
for plate in sorted(plate_set):
    for well in sorted(well_set):
        pq_path = parquet_dir / get_filename(
            {"plate": int(plate), "well": well}, name_suffix, "parquet"
        )
        if not pq_path.exists():
            print(f"[warn] Skipping missing: {pq_path}")
            continue
        try:
            # Read only the file metadata, not the data itself
            all_cols = pq.ParquetFile(pq_path).schema.names
        except Exception:
            all_cols = list(pd.read_parquet(pq_path).head(0).columns)
        break
    if all_cols:
        break

print(f"Data source: {TRAINING_DATA_SOURCE}")
print(f"Found {len(all_cols)} columns:\n")
for col in all_cols:
    print(col)

### 1b. <font color='red'>SET PARAMETERS</font>: Gating Settings

**Feature gating (optional):**
Use gating to prioritize cells with specific characteristics for labeling. For example, to find mitotic cells, one might set `GATE_FEATURE = "nucleus_DAPI_mad"` with a high `GATE_MIN_PERCENTILE` to prioritize cells with bright nuclei.

- `GATE_FEATURE`: Feature to gate by (e.g., `"nucleus_DAPI_mad"`). Set to `None` to skip gating.
- `GATE_MIN`: Minimum value (exclusive). Cells below this will be deprioritized.
- `GATE_MAX`: Maximum value (exclusive). Cells above this will be deprioritized.
- `GATE_MIN_PERCENTILE`: Percentile minimum (0-1). Cells below this percentile will be deprioritized.
- `GATE_MAX_PERCENTILE`: Percentile maximum (0-1). Cells above this percentile will be deprioritized.
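The gating options can be read as a lower and an upper bound on one feature. The sketch below is one plausible reading, not the `prepare_mask_dataframes` implementation. It assumes that absolute and percentile bounds combine by taking the tighter of the two, that percentile bounds are exclusive like `GATE_MIN`/`GATE_MAX`, and that percentiles use linear interpolation. `quantile` and `split_by_gate` are hypothetical helpers that operate on a plain list instead of a DataFrame.

```python
def quantile(values, q):
    """Linear-interpolation quantile (q in [0, 1]) of a non-empty sequence."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def split_by_gate(values, gate_min=None, gate_max=None, min_pct=None, max_pct=None):
    """Return (in_gate, out_of_gate) index lists for one feature's values."""
    lower, upper = float("-inf"), float("inf")
    # Assumption: absolute and percentile bounds combine by taking the tighter one
    if gate_min is not None:
        lower = max(lower, gate_min)
    if min_pct is not None:
        lower = max(lower, quantile(values, min_pct))
    if gate_max is not None:
        upper = min(upper, gate_max)
    if max_pct is not None:
        upper = min(upper, quantile(values, max_pct))
    in_gate, out_of_gate = [], []
    for i, v in enumerate(values):
        # Exclusive bounds, as described for GATE_MIN / GATE_MAX
        (in_gate if lower < v < upper else out_of_gate).append(i)
    return in_gate, out_of_gate


dapi_mad = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
inside, outside = split_by_gate(dapi_mad, min_pct=0.8)
print(inside)  # [8, 9]
```

Deprioritized cells are not discarded. They form the out-of-gate pool that batches can still draw from.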

**Batch settings:**
- `BATCH_SIZE`: Number of images to display per round of classification. Default is 10.
- `OUT_OF_GATE_COUNT`: Number of out-of-gate images to include in each batch for diversity (0 to `BATCH_SIZE`). Default is 1. Adjust this to build a more balanced training set for your classifier.
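Together, `BATCH_SIZE` and `OUT_OF_GATE_COUNT` describe how each batch is composed. The sketch below assumes a batch holds up to `OUT_OF_GATE_COUNT` out-of-gate items, fills the remaining slots from the in-gate pool, and shrinks if either pool runs short. `compose_batch` is hypothetical, and `_render_next_batch` may sample differently.

```python
import random


def compose_batch(in_gate, out_of_gate, batch_size=10, out_of_gate_count=1, seed=42):
    """Draw one labeling batch that mixes in-gate and out-of-gate items."""
    if not 0 <= out_of_gate_count <= batch_size:
        raise ValueError("out_of_gate_count must be between 0 and batch_size")
    rng = random.Random(seed)  # seeded for reproducibility, like RANDOM_SEED
    n_out = min(out_of_gate_count, len(out_of_gate))
    # Fill the remaining slots from the in-gate pool
    n_in = min(batch_size - n_out, len(in_gate))
    batch = rng.sample(in_gate, n_in) + rng.sample(out_of_gate, n_out)
    rng.shuffle(batch)  # avoid always showing the out-of-gate items last
    return batch


batch = compose_batch(list(range(100)), list(range(100, 200)), batch_size=10, out_of_gate_count=1)
print(len(batch))  # 10
```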

In [None]:
GATE_FEATURE = None
GATE_MIN = None
GATE_MAX = None
GATE_MIN_PERCENTILE = None
GATE_MAX_PERCENTILE = None
BATCH_SIZE = 10
OUT_OF_GATE_COUNT = 1

In [None]:
summary_df, in_gate_df, out_of_gate_df, gate_dbg = prepare_mask_dataframes(
    mode=MODE,
    phenotype_fp=PHENOTYPE_OUTPUT_FP,
    plates=PLATES_TO_CLASSIFY,
    wells=WELLS_TO_CLASSIFY,
    keys=_KEYS,
    gate_feature=GATE_FEATURE,
    gate_min=GATE_MIN,
    gate_max=GATE_MAX,
    gate_min_percentile=GATE_MIN_PERCENTILE,
    gate_max_percentile=GATE_MAX_PERCENTILE,
    verbose=True
)

In [None]:
# Load existing training data if adding to it
if ADD_TRAINING_DATA and EXISTING_TRAINING_DATA:
    seeded_df, _EXISTING_KEYS = load_existing_training_data(
        EXISTING_TRAINING_DATA, MODE, CLASS_TITLE
    )
    if not RELABEL_CLASSIFICATIONS:
        in_gate_df, out_of_gate_df = filter_existing_from_pools(
            in_gate_df, out_of_gate_df, _EXISTING_KEYS
        )
else:
    seeded_df = None
    _EXISTING_KEYS = set()

print(f"[training] Existing keys loaded: {len(_EXISTING_KEYS)}")

### 1c. <font color='red'>SET PARAMETERS</font>: Display Settings

**Channel visualization:**
- `DISPLAY_CHANNEL`: Channels to display for manual classification.
- `CHANNEL_COLORS`: Colors for each channel, in the same order as `DISPLAY_CHANNEL`. See [matplotlib colors](https://matplotlib.org/stable/gallery/color/named_colors.html).

**Selection method:**
- `TRAINING_DATASET_SELECTION`: Choose `"random"` or `"top_n"`.
  - `"random"`: Randomly select masks from specified plates and wells.
  - `"top_n"`: Select masks from tiles with the most instances. Useful for checking alignment with original images.
- `TOP_N`: If using `"top_n"`, specify which ranked tile to use.
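The `"top_n"` option ranks tiles by how many masks they contain. The sketch below assumes `TOP_N` is a 0-based rank (so the default `0` picks the tile with the most masks) and breaks ties by the tile key. Both choices are assumptions, and `tile_at_rank` is a hypothetical helper.

```python
from collections import Counter


def tile_at_rank(masks, top_n=0):
    """Return the (plate, well, tile) with the top_n-th most masks (0 = most)."""
    counts = Counter((m["plate"], m["well"], m["tile"]) for m in masks)
    # Sort by count descending, then by key for a deterministic tie-break
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    if top_n >= len(ranked):
        raise IndexError(f"only {len(ranked)} tiles available")
    return ranked[top_n][0]


masks = (
    [{"plate": 1, "well": "A1", "tile": 7}] * 5
    + [{"plate": 1, "well": "A1", "tile": 3}] * 3
    + [{"plate": 1, "well": "A1", "tile": 1}] * 1
)
print(tile_at_rank(masks, 0))  # (1, 'A1', 7)
```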

**Other settings:**
- `SCALE_BAR`: Scale bar length in pixels. If the value exceeds the image size, the scale bar is drawn as dashed lines.
- `RANDOM_SEED`: Random seed for reproducibility.

In [None]:
DISPLAY_CHANNEL = ["DAPI"]
CHANNEL_COLORS = ["b"]
TRAINING_DATASET_SELECTION = "random"
TOP_N = 0
SCALE_BAR = 30
RANDOM_SEED = 42

In [None]:
# Initialize state and resolve channels
CHANNEL_NAMES = config["phenotype"]["channel_names"]
CHANNEL_INDICES = [CHANNEL_NAMES.index(ch) for ch in DISPLAY_CHANNEL]
resolved_colors = resolve_channel_colors(DISPLAY_CHANNEL, CHANNEL_COLORS)

_STATE = initialize_labeling_state(
    random_seed=RANDOM_SEED,
    mode=MODE,
    class_title=CLASS_TITLE,
    keys=_KEYS,
    existing_classified_df=seeded_df,
)

_render_next_batch(
    state=_STATE,
    DISPLAY_CHANNEL=DISPLAY_CHANNEL,
    ADD_TRAINING_DATA=ADD_TRAINING_DATA,
    keys=_KEYS,
    CLASS_TITLE=CLASS_TITLE,
    CLASSIFICATION=CLASSIFICATION,
    RELABEL_CLASSIFICATIONS=RELABEL_CLASSIFICATIONS,
    TRAINING_DATASET_SELECTION=TRAINING_DATASET_SELECTION,
    BATCH_SIZE=BATCH_SIZE,
    summary_df=summary_df,
    in_gate_df=in_gate_df,
    out_of_gate_df=out_of_gate_df,
    OUT_OF_GATE_COUNT=OUT_OF_GATE_COUNT,
    CHANNEL_INDICES=CHANNEL_INDICES,
    PHENOTYPE_OUTPUT_FP=PHENOTYPE_OUTPUT_FP,
    CHANNEL_NAMES=CHANNEL_NAMES,
    MODE=MODE,
    RESOLVED_COLORS=resolved_colors,
    SCALE_BAR=SCALE_BAR,
    EXISTING_KEYS=_EXISTING_KEYS,
    GATE_FEATURE_PRESENT=(GATE_FEATURE is not None),
)

In [None]:
consolidated_df, training_dataset_out_path = consolidate_manual_classifications(
    manual_classified_df=_STATE["manual_classified_df"],
    class_title=CLASS_TITLE,
    mode=MODE,
    phenotype_output_fp=PHENOTYPE_OUTPUT_FP,
    classifier_output_dir=CLASSIFIER_OUTPUT_DIR,
    write=True,
    verbose=True,
)
display(consolidated_df)

## 2. Training

Configure training parameters and train multiple model types to find the best classifier.

### 2a. <font color='red'>SET PARAMETERS</font>: Training Dataset Configuration

- `TRAINING_DATASET_FP`: Path(s) to the training dataset(s), as a list. Set to `None` to use the dataset labeled and consolidated in section 1 (`consolidated_df`).
- `METADATA_COLS_FP`: Path where the metadata column list is saved for use in the aggregate pipeline.
- `METADATA_COLS`: Columns to treat as metadata rather than features. Defaults are provided; modify as needed based on the column listing printed in section 1a.
- `TRAINING_CHANNELS`: Channels to include in training features.
- `TRAINING_NAME`: Name identifier for this training run.
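`METADATA_COLS` and `TRAINING_CHANNELS` together determine which columns become training features. The sketch below shows one way to do that split. It assumes feature names embed the channel as an underscore-separated token (for example `nucleus_DAPI_mad`), that channel-agnostic features such as areas are kept, and that any feature mentioning a non-training channel is dropped. `select_feature_columns` is hypothetical, and the pipeline's own use of `feature_markers`/`exclude_markers` may differ. In particular, the tokenization would break for channel names that themselves contain underscores.

```python
def select_feature_columns(columns, metadata_cols, training_channels, all_channels):
    """Split columns into (features to train on, excluded feature columns)."""
    meta = set(metadata_cols)
    excluded_channels = set(all_channels) - set(training_channels)
    keep, drop = [], []
    for col in columns:
        if col in meta:
            continue  # metadata is never used as a feature
        tokens = set(col.split("_"))
        # Drop any feature that mentions a channel not chosen for training
        (drop if tokens & excluded_channels else keep).append(col)
    return keep, drop


cols = ["plate", "well", "nucleus_DAPI_mad", "cell_GFP_mean", "cell_area"]
keep, drop = select_feature_columns(cols, ["plate", "well"], ["DAPI"], ["DAPI", "GFP"])
print(keep)  # ['nucleus_DAPI_mad', 'cell_area']
```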

In [None]:
TRAINING_DATASET_FP = None
METADATA_COLS_FP = "config/cell_data_metadata_cols.tsv"

# Set metadata columns based on output above
METADATA_COLS = [
    "plate",
    "well",
    "tile",
    "cell_0",
    "i_0",
    "j_0",
    "site",
    "cell_1",
    "i_1",
    "j_1",
    "distance",
    "fov_distance_0",
    "fov_distance_1",
    "cell_barcode_0",
    "gene_symbol_0",
    "mapped_single_gene",
    "channels_min",
    "nucleus_i",
    "nucleus_j",
    "nucleus_bounds_0",
    "nucleus_bounds_1",
    "nucleus_bounds_2",
    "nucleus_bounds_3",
    "cell_i",
    "cell_j",
    "cell_bounds_0",
    "cell_bounds_1",
    "cell_bounds_2",
    "cell_bounds_3",
    "cytoplasm_i",
    "cytoplasm_j",
    "cytoplasm_bounds_0",
    "cytoplasm_bounds_1",
    "cytoplasm_bounds_2",
    "cytoplasm_bounds_3",
]

TRAINING_CHANNELS = ['DAPI']
TRAINING_NAME = "cell_classifier"

# Save metadata cols to file for use in aggregate pipeline
pd.Series(METADATA_COLS).to_csv(METADATA_COLS_FP, index=False, header=False, sep="\t")
print(f"Saved {len(METADATA_COLS)} metadata columns to {METADATA_COLS_FP}")

In [None]:
if TRAINING_DATASET_FP is not None:
    data = load_cellprofiler_data(TRAINING_DATASET_FP)
else:
    print("No training dataset provided, using last classified dataset")
    data = consolidated_df

channel_names = config["phenotype"]["channel_names"]
feature_markers = {c: True for c in channel_names if c in TRAINING_CHANNELS}
exclude_markers = [c for c in channel_names if c not in TRAINING_CHANNELS]

print(f"Class names: {CLASSIFICATION}")
print(f"Class names to stored numeric values: {class_mapping}")
print(f"Features to train upon: {feature_markers}")
print(f"Features to exclude: {exclude_markers}")
print(f"Target column: {CLASS_TITLE}")

### 2b. <font color='red'>SET PARAMETERS</font>: Filter Training Classes

- `REMOVE_MASK_LABELS`: List of class labels to exclude from training (e.g., `unknown`). Set to `None` to keep all classes.
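Removing a class amounts to filtering both the label list and the name-to-code mapping. This hypothetical `drop_classes` sketch keeps the original numeric codes rather than renumbering the remaining classes. That choice is an assumption, and `filter_classes` (which also returns a `class_id`) may behave differently.

```python
def drop_classes(class_names, class_mapping, remove=None):
    """Remove unwanted class names, keeping the original numeric codes."""
    remove = set(remove or [])
    unknown = remove - set(class_names)
    if unknown:
        # Catch typos early instead of silently keeping every class
        raise ValueError(f"not in CLASSIFICATION: {sorted(unknown)}")
    labels = [c for c in class_names if c not in remove]
    mapping = {c: class_mapping[c] for c in labels}
    return labels, mapping


labels, mapping = drop_classes(
    ["Mitotic", "Interphase", "unknown"],
    {"Mitotic": 1, "Interphase": 2, "unknown": 3},
    ["unknown"],
)
print(labels)  # ['Mitotic', 'Interphase']
```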

In [None]:
REMOVE_MASK_LABELS = None
class_labels, filtered_class_mapping, class_id = filter_classes(CLASSIFICATION, class_mapping, REMOVE_MASK_LABELS)

In [None]:
model_configs = [
    ('rf_standard', 'rf', 'standard', None),
    ('svc_standard', 'svc', 'standard', None),
    ('xgb_standard', 'xgb', 'standard', None),
    ('lgb_standard', 'lgb', 'standard', None),
    ('xgb_robust', 'xgb', 'robust', None),
    ('xgb_minmax', 'xgb', 'minmax', None),
    ('xgb_none', 'xgb', 'none', None),
    ('xgb_none_var', 'xgb', 'none', {'enhance': True, 'remove_low_variance': True, 'remove_correlated': False, 'select_k_best': None}),
    ('xgb_none_corr', 'xgb', 'none', {'enhance': True, 'remove_low_variance': False, 'remove_correlated': True, 'select_k_best': None}),
    ('xgb_none_kbest100', 'xgb', 'none', {'enhance': True, 'remove_low_variance': False, 'remove_correlated': False, 'select_k_best': 100}),
    ('xgb_none_combined', 'xgb', 'none', {'enhance': True, 'remove_low_variance': True, 'remove_correlated': True, 'select_k_best': 100}),
]
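Each entry in `model_configs` reads as `(run name, model type, scaler, feature-selection options)`. The option keys suggest low-variance removal, correlated-feature removal, and top-k feature selection. The plain-Python sketch below illustrates the first two filters only. The thresholds (`min_variance`, `max_abs_corr`) and the greedy keep-the-first-feature rule are assumptions, not necessarily what `train_classifier_pipeline` does.

```python
import math


def variance(vals):
    """Population variance of a non-empty sequence."""
    m = sum(vals) / len(vals)
    return sum((v - m) ** 2 for v in vals) / len(vals)


def pearson(x, y):
    """Pearson correlation of two equal-length, non-constant sequences."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def filter_features(features, min_variance=1e-8, max_abs_corr=0.95):
    """features: dict name -> list of values. Returns the kept names, in order."""
    # 1) Drop (near-)constant features
    varied = [name for name, vals in features.items() if variance(vals) > min_variance]
    # 2) Greedily keep a feature unless it is highly correlated with one already kept
    kept = []
    for name in varied:
        if all(abs(pearson(features[name], features[k])) <= max_abs_corr for k in kept):
            kept.append(name)
    return kept


features = {
    "a": [1, 2, 3, 4],
    "b": [2, 4, 6, 8],  # perfectly correlated with "a"
    "c": [5, 5, 5, 5],  # constant
    "d": [4, 1, 3, 2],
}
kept = filter_features(features)
print(kept)  # ['a', 'd']
```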

In [None]:
pipeline_result = train_classifier_pipeline(
    data=data,
    class_title=CLASS_TITLE,
    class_id=class_id,
    class_labels=class_labels,
    filtered_class_mapping=filtered_class_mapping,
    metadata_cols=METADATA_COLS,
    feature_markers=feature_markers,
    exclude_markers=exclude_markers,
    training_name=TRAINING_NAME,
    model_configs=model_configs,
    classifier_output_dir=CLASSIFIER_OUTPUT_DIR,
    training_channels=TRAINING_CHANNELS,
    verbose=True,
)
multiclass_df = pipeline_result['metrics_df']
print(multiclass_df.head())

run_dir = pipeline_result["dirs"]["run"]
print(f"All outputs written to: {run_dir}")

## 3. Selection

Run the cells below to evaluate trained models and select the best one with an appropriate confidence threshold.

### 3a. <font color='red'>SET PARAMETERS</font>: Model Selection Settings

- `MODEL_RUN_DIR`: Path to the model run directory to test (e.g., `"run_20250903_114514"`). Set to `None` to use the last training run.
- `TEST_PLATES`: Plates to use for testing.
- `TEST_WELLS`: Wells to use for testing.
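Finding "the last training run" by name works because the run directories are timestamped. The sketch below assumes runs follow the `run_YYYYMMDD_HHMMSS` pattern shown in the example above and sit directly under one parent folder. `latest_run_dir` is a hypothetical helper.

```python
import re
import tempfile
from pathlib import Path

RUN_DIR_PATTERN = re.compile(r"run_\d{8}_\d{6}")


def latest_run_dir(parent):
    """Return the newest run_YYYYMMDD_HHMMSS subdirectory of `parent`, or None."""
    runs = [
        p for p in Path(parent).iterdir()
        if p.is_dir() and RUN_DIR_PATTERN.fullmatch(p.name)
    ]
    # Zero-padded timestamps sort chronologically as plain strings
    return max(runs, key=lambda p: p.name, default=None)


with tempfile.TemporaryDirectory() as tmp:
    for name in ["run_20250903_114514", "run_20250904_090000", "notes"]:
        (Path(tmp) / name).mkdir()
    newest = latest_run_dir(tmp).name
print(newest)  # run_20250904_090000
```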

In [None]:
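# Set to a run directory name (e.g. "run_20250903_114514") to test an earlier run;
# leave as None to use the run produced by the training cell above
MODEL_RUN_DIR = None
TEST_PLATES = ["1", "2"]
TEST_WELLS = ["A1", "A2", "A3"]

In [None]:
plate = TEST_PLATES
well = TEST_WELLS

if MODEL_RUN_DIR is None:
    # `run_dir` is set by the training cell in section 2
    CLASSIFIER_DIR_PATH = Path(run_dir)
else:
    CLASSIFIER_DIR_PATH = Path(CLASSIFIER_OUTPUT_DIR) / "classifier" / str(MODEL_RUN_DIR)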

plates = [str(p) for p in (plate if isinstance(plate, (list, tuple)) else [plate])]
wells = [str(w) for w in (well if isinstance(well, (list, tuple)) else [well])]

if MODE == "cell":
    name_suffix = "phenotype_cp.parquet"
elif MODE == "vacuole":
    name_suffix = "phenotype_vacuoles.parquet"
else:
    raise ValueError(f"Unsupported MODE: {MODE!r}; expected 'cell' or 'vacuole'")

master_phenotype_df, meta = build_master_phenotype_df(
    plates=plates,
    wells=wells,
    name_suffix=name_suffix,
    parquet_dir=PHENOTYPE_OUTPUT_FP / "parquets",
    display_fn=display,
    verbose=True
)

metadata, features = split_cell_data(master_phenotype_df, METADATA_COLS)
print(metadata.shape, features.shape)

**Note:** The cell below displays evaluation statistics for the selected model run. If no model run was specified, the most recent training run is displayed. Based on the accuracy and F1 scores, select a model in the following cell. Your selection is saved when you run the final cell of this notebook.

In [None]:
_=display_pngs_in_plots_and_list_models(
    CLASSIFIER_DIR_PATH, width=1200
)

### 3b. <font color='red'>SET PARAMETERS</font>: Model and Montage Settings

- `CLASSIFIER_MODEL`: Name of the model to use (e.g., `"xgb_standard"`). Set to `None` to use the best performing model.
- `MONTAGE_CHANNEL`: Channel to use for montage images. Should be one of the channels used for training.
- `COLLAPSE_COLS`: Columns to collapse on when creating classification summaries (e.g., `["plate", "well"]`).
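Leaving `CLASSIFIER_MODEL` as `None` selects the best performing model. The sketch below picks the highest-scoring row from a list of per-model metrics. The `"model"` and `"f1_macro"` keys and the scores are made-up placeholders; the actual columns of `metrics_df` and the metric `resolve_classifier_model_dill_path` ranks by may differ.

```python
def best_model(metrics, metric="f1_macro"):
    """Pick the model name with the highest score; ties go to the first listed."""
    if not metrics:
        raise ValueError("no model metrics to choose from")
    # max() returns the first maximal element, so earlier configs win ties
    best = max(metrics, key=lambda row: row[metric])
    return best["model"]


# Illustrative scores only, not results from any real run
metrics = [
    {"model": "rf_standard", "f1_macro": 0.91},
    {"model": "xgb_standard", "f1_macro": 0.94},
    {"model": "svc_standard", "f1_macro": 0.94},
]
print(best_model(metrics))  # xgb_standard
```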

In [None]:
CLASSIFIER_MODEL = None
MONTAGE_CHANNEL = "DAPI"
COLLAPSE_COLS = ["plate", "well"]

In [None]:
CLASSIFIER_PATH, model_name = resolve_classifier_model_dill_path(CLASSIFIER_DIR_PATH, CLASSIFIER_MODEL)
print("Selected model: " + model_name)
_ = show_model_evaluation_pngs(CLASSIFIER_DIR_PATH, model_name, width=500)

In [None]:
classifier = CellClassifier.load(CLASSIFIER_PATH)
classified_metadata, classified_features = classifier.classify_cells(metadata, features)

print(classified_metadata.head())
CELL_CLASSES = list(classified_metadata[CLASS_TITLE].unique())

print(CLASS_TITLE + " counts:")
print(classified_metadata[CLASS_TITLE].value_counts())

print("\n" + CLASS_TITLE + " confidences:")
classified_metadata[CLASS_TITLE + "_confidence"].hist()
plt.show()

### 3c. <font color='red'>SET PARAMETERS</font>: Confidence Calibration

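- `CONFIDENCE_CORRECTION`: Set to `"post-hoc"` to apply post-hoc confidence calibration, or `None` to leave confidences uncorrected.
- `CALIBRATION_DATASET_FP`: Path to a labeled calibration dataset. Set to `None` to fall back to the training dataset(s) listed under `classify.training_datasets` in the config.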
- `CALIBRATION_METHOD`: Calibration method to use. Options are `"isotonic"` or `"sigmoid"`.
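Isotonic calibration fits a non-decreasing step function from raw confidence to the accuracy observed on labeled data. The `"sigmoid"` option (Platt scaling) fits a logistic curve instead. The pure-Python sketch below uses the pool-adjacent-violators algorithm. It is not the `calibrate_confidence` implementation, and it uses a simple step lookup without interpolating between steps. Isotonic fits tend to overfit small calibration sets, which is consistent with the cell below requiring a minimum sample count (`min_samples_isotonic`).

```python
def fit_isotonic(confidences, correct):
    """Pool-adjacent-violators fit of accuracy (0/1) against confidence.

    Returns a list of (max_confidence_in_block, calibrated_value) steps.
    """
    pairs = sorted(zip(confidences, correct))
    blocks = []  # each block: [sum_of_targets, count, max_confidence]
    for conf, y in pairs:
        blocks.append([float(y), 1, conf])
        # Merge backwards while the block means would decrease
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            s, n, c = blocks.pop()
            blocks[-1][0] += s
            blocks[-1][1] += n
            blocks[-1][2] = c
    return [(c, s / n) for s, n, c in blocks]


def apply_isotonic(steps, confidence):
    """Step-function lookup; confidences past the last step are clipped to it."""
    for upper, value in steps:
        if confidence <= upper:
            return value
    return steps[-1][1]


steps = fit_isotonic([0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0, 1, 0, 1, 1, 1])
print(apply_isotonic(steps, 0.65))  # 0.5
```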

In [None]:
CONFIDENCE_CORRECTION = None
CALIBRATION_DATASET_FP = None
CALIBRATION_METHOD = "isotonic"

In [None]:
if CALIBRATION_DATASET_FP is not None:
    manual_labeled_data = load_cellprofiler_data(CALIBRATION_DATASET_FP)
else:
    print("No calibration dataset provided, using training dataset")
    # Fall back to the training dataset(s) listed under `classify` in the config;
    # .get avoids a KeyError when the section or key is missing
    CALIBRATION_DATASET_FP = (config.get("classify") or {}).get("training_datasets")
    if not CALIBRATION_DATASET_FP:
        raise ValueError("No calibration dataset provided and no training dataset found in config.")
    manual_labeled_data = load_cellprofiler_data(CALIBRATION_DATASET_FP)

classified_metadata, meta = calibrate_confidence(
    master_phenotype_df=master_phenotype_df,
    classified_metadata=classified_metadata,
    manual_labeled_data=manual_labeled_data,
    classify_by=MODE,
    class_title=CLASS_TITLE,
    classifier_path=CLASSIFIER_PATH,
    confidence_correction=CONFIDENCE_CORRECTION,
    calibration_method=CALIBRATION_METHOD,
    test_plate=plates,
    test_well=wells,
    min_samples_isotonic=50,
    verbose=True,
)

print(meta)

In [None]:
fig, axes, montages, titles, ORDERED_CLASSES, summary_df = build_montages_and_summary(
    master_phenotype_df=master_phenotype_df,
    classified_metadata=classified_metadata,
    classify_by=MODE,
    class_mapping=class_mapping,
    class_title=CLASS_TITLE,
    root_fp=ROOT_FP,
    channels=config["phenotype"]["channel_names"],
    montage_channel=MONTAGE_CHANNEL,
    collapse_cols=COLLAPSE_COLS,
    verbose=True,
    show_figure=True,
    display_fn=display,
)

### 3d. <font color='red'>SET PARAMETERS</font>: Rankline UI Settings

- `MINIMUM_DIFFERENCE`: Minimum confidence difference for comparison.

In [None]:
MINIMUM_DIFFERENCE = 0.001

In [None]:
w = launch_rankline_ui(
    classified_metadata=classified_metadata,
    class_title=CLASS_TITLE,
    classify_by=MODE,
    class_mapping=class_mapping,
    phenotype_output_fp=PHENOTYPE_OUTPUT_FP,
    channel_names=list(config["phenotype"]["channel_names"]),
    display_channels=DISPLAY_CHANNEL,
    channel_colors=CHANNEL_COLORS,
    test_plate=plates,
    test_well=wells,
    scale_bar_px=SCALE_BAR,
    minimum_difference=MINIMUM_DIFFERENCE,
    thumbnail_px=150,
    auto_display=True,
)

### 3e. <font color='red'>SET PARAMETERS</font>: Confidence Threshold

- `CONFIDENCE_THRESHOLD`: Minimum confidence required for predictions. Cells below this threshold will be filtered out when splitting into classes in the aggregate pipeline. Choose based on the confidence distribution and rankline UI above.
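The threshold is applied when cells are split by class. The sketch below groups predictions by class and drops those below the threshold. It uses the `<CLASS_TITLE>_confidence` column naming from this notebook. Whether the aggregate pipeline keeps confidences exactly equal to the threshold (`>=`, as here) is an assumption, and `split_by_class` is a hypothetical helper.

```python
def split_by_class(records, class_col, threshold):
    """Group confident predictions by class; drop those below `threshold`."""
    conf_col = class_col + "_confidence"  # column naming used in this notebook
    groups, dropped = {}, 0
    for row in records:
        if row[conf_col] < threshold:
            dropped += 1
            continue
        groups.setdefault(row[class_col], []).append(row)
    return groups, dropped


records = [
    {"cell_stage": 1, "cell_stage_confidence": 0.9},
    {"cell_stage": 2, "cell_stage_confidence": 0.4},
    {"cell_stage": 2, "cell_stage_confidence": 0.5},
    {"cell_stage": 1, "cell_stage_confidence": 0.7},
]
groups, dropped = split_by_class(records, "cell_stage", 0.5)
print({k: len(v) for k, v in groups.items()}, dropped)  # {1: 2, 2: 1} 1
```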

In [None]:
CONFIDENCE_THRESHOLD = 0.5

In [None]:
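# Update (rather than replace) the classify section so existing keys,
# such as `training_datasets`, are preserved
config["classify"] = {
    **(config.get("classify") or {}),
    "CLASSIFIER_PATH": str(CLASSIFIER_PATH),
    "CONFIDENCE_THRESHOLD": CONFIDENCE_THRESHOLD,
    "METADATA_COLS_FP": str(METADATA_COLS_FP),
}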

with open(CONFIG_FILE_PATH, "w") as config_file:
    config_file.write(CONFIG_FILE_HEADER)
    yaml.safe_dump(config, config_file, default_flow_style=False, sort_keys=False)

print("Saved classify settings to config.")