# Train Classifier

This notebook provides tools for training custom cell or vacuole classifiers. It covers labeling data, training models, and selecting the best classifier.

Cells marked with <font color='red'>SET PARAMETERS</font> contain crucial variables that need to be set according to your specific experimental setup and data organization.
Please review and modify these variables as needed before proceeding with the analysis.

## <font color='red'>SET PARAMETERS</font>

### Fixed parameters for aggregate module

- `CONFIG_FILE_PATH`: Path to a Brieflow config file used during processing. Absolute or relative to where workflows are run from.

In [None]:
# Brieflow config file. It is read in the setup cell below and rewritten later in
# this notebook with the `train_classifier` settings.
CONFIG_FILE_PATH = "config/config.yml"

In [None]:
import os
import re
from pathlib import Path
from typing import List

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml

from lib.aggregate.cell_classification import CellClassifier
from lib.aggregate.cell_data_utils import split_cell_data
from lib.classify.apply import (
    build_master_phenotype_df,
    build_montages_and_summary,
    display_pngs_in_plots_and_list_models,
    launch_rankline_ui,
    resolve_classifier_model_dill_path,
    show_model_evaluation_pngs,
)
from lib.classify.calibration import calibrate_confidence
from lib.classify.labeling import (
    _ensure_mc_schema,
    _mode_norm,
    _normalize_keys,
    _pq_path_for,
    _render_next_batch,
    consolidate_manual_classifications,
    prepare_mask_dataframes,
)
from lib.classify.train import (
    filter_classes,
    load_cellprofiler_data,
    train_classifier_pipeline,
)

In [None]:
# Parse the Brieflow config once; later cells read from `config` and write it back.
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

# `root_fp` from the config is joined onto the directory two levels above the
# resolved config file, then normalised.
_config_dir = Path(CONFIG_FILE_PATH).resolve().parent
ROOT_FP = (_config_dir.parent / config["all"]["root_fp"]).resolve()

# Phenotype outputs and their parquet files live under the root.
PHENOTYPE_OUTPUT_FP = ROOT_FP / "phenotype"
PQ_DIR = PHENOTYPE_OUTPUT_FP / "parquets"

# Everything this notebook produces goes under <root>/classifier.
CLASSIFIER_OUTPUT_DIR = ROOT_FP / "classifier"
CLASSIFIER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Created/verified: {CLASSIFIER_OUTPUT_DIR}")

## 1. Labeling

Use an interactive UI to label cells or vacuoles to create training data for machine learning models.

### 1a. <font color='red'>SET PARAMETERS</font>: Classification Settings

**If adding to an existing training dataset:**
- `ADD_TRAINING_DATA`: Set to `False` for first-time training, `True` to add to an existing dataset.
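- `EXISTING_TRAINING_DATA`: Only set if `ADD_TRAINING_DATA` is `True`. Specify the filename (not a full path) of the existing training dataset. The file must be located in the `training_dataset` folder of the classifier output directory and be named `{mode}_classifier_training_dataset_for_{CLASS_TITLE}_{YYYYMMDD}_{HHMMSS}.parquet`, where the mode and `CLASS_TITLE` match the settings below.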
- `RELABEL_CLASSIFICATIONS`: Set to `True` to revisit and modify labels from the existing dataset. Previously labeled data will be shown first.

**Classification parameters:**
- `CLASSIFY_BY_VACUOLE_OR_CELL`: Set to `"vacuole"` or `"cell"`.
- `CLASS_TITLE`: Name of the new column added to the phenotype dataframe.
- `CLASSIFICATION`: List of categories for classification. Categories appear as 1, 2, 3... in output, corresponding to list order.
- `PLATES_TO_CLASSIFY`: Plates to include in classification.
- `WELLS_TO_CLASSIFY`: Wells among the specified plates to classify.

In [None]:
# True: add to an existing training dataset; False: first-time training.
ADD_TRAINING_DATA = True
# Filename of the existing training parquet; only used when ADD_TRAINING_DATA is True.
EXISTING_TRAINING_DATA = "vacuole_classifier_training_dataset_for_parasite_count_20250818_002808.parquet"
# True: revisit labels from the existing dataset (previously labeled data is shown first).
RELABEL_CLASSIFICATIONS = True
# Unit being labeled: "vacuole" or "cell" (validated in the next cell).
CLASSIFY_BY_VACUOLE_OR_CELL = "vacuole"
# Name of the label column added to the phenotype dataframe.
CLASS_TITLE = "parasite_count"
# Class names; they are stored as 1, 2, 3, ... following list order.
CLASSIFICATION = ["1 parasite", "2-3 parasite", "4-7 parasite", "8+ parasite"]
# Plates, and wells within those plates, to draw masks from.
PLATES_TO_CLASSIFY = ["1", "2"]
WELLS_TO_CLASSIFY = ["A1", "A2", "A3"]

In [None]:
# Normalise the mode once so later comparisons ignore case and stray whitespace.
_mode_cfg = str(CLASSIFY_BY_VACUOLE_OR_CELL).strip().lower()
if not (_mode_cfg == "cell" or _mode_cfg == "vacuole"):
    raise ValueError(f"CLASSIFY_BY_VACUOLE_OR_CELL must be 'cell' or 'vacuole', got: {CLASSIFY_BY_VACUOLE_OR_CELL!r}")

# Stays None unless an existing dataset is being extended.
EXISTING_TRAINING_PATH = None
if ADD_TRAINING_DATA:
    if not EXISTING_TRAINING_DATA:
        raise ValueError("ADD_TRAINING_DATA=True but EXISTING_TRAINING_DATA is not provided (filename).")

    # Existing datasets are looked up in <classifier output>/training_dataset.
    train_dir = Path(CLASSIFIER_OUTPUT_DIR) / "training_dataset"
    EXISTING_TRAINING_PATH = train_dir / str(EXISTING_TRAINING_DATA)
    if not EXISTING_TRAINING_PATH.exists():
        raise FileNotFoundError(f"Existing training parquet not found: {EXISTING_TRAINING_PATH}")

    # The filename encodes the mode and class title; both must agree with the
    # settings above so that old and new labels are comparable.
    m = re.match(
        r"^(cell|vacuole)_classifier_training_dataset_for_(.+?)_\d{8}_\d{6}\.parquet$",
        EXISTING_TRAINING_PATH.name,
    )
    if m is None:
        raise ValueError(
            "EXISTING_TRAINING_DATA filename must match:\n"
            "  '{mode}_classifier_training_dataset_for_{CLASS_TITLE}_{YYYYMMDD}_{HHMMSS}.parquet'"
        )
    mode_from_file, class_title_from_file = m.groups()

    if mode_from_file != _mode_cfg:
        raise ValueError(
            f"Mode mismatch: file says mode='{mode_from_file}', but CLASSIFY_BY_VACUOLE_OR_CELL='{_mode_cfg}'. "
            "Please align your settings with the existing parquet."
        )
    if class_title_from_file != CLASS_TITLE:
        raise ValueError(
            f"CLASS_TITLE mismatch: file says CLASS_TITLE='{class_title_from_file}', but CLASS_TITLE='{CLASS_TITLE}'. "
            "Please align your settings with the existing parquet."
        )

# Numeric label (1-based, in CLASSIFICATION order) -> class name.
class_mapping = {'label_to_class': dict(enumerate(CLASSIFICATION, start=1))}

print(
    "[training] Config OK.",
    f"ADD_TRAINING_DATA={ADD_TRAINING_DATA}, RELABEL_CLASSIFICATIONS={RELABEL_CLASSIFICATIONS}",
)

In [None]:
# Print every candidate feature column found in the selected plates/wells
# (presumably to help pick THRESHOLD_FEATURE in the next section).

# pyarrow can read column names from the parquet schema alone. Import it once
# here instead of on every loop iteration, and treat "not installed" separately
# from "schema could not be read".
try:
    import pyarrow.parquet as pq
except ImportError:
    pq = None


def _parquet_column_names(pq_path):
    """Return the column names of the parquet file at ``pq_path``.

    Uses the pyarrow schema when pyarrow is available; if that is unavailable
    or fails, falls back to loading the file with pandas and reading its
    column index.
    """
    if pq is not None:
        try:
            return list(pq.ParquetFile(pq_path).schema.names)
        except Exception as exc:
            # Best-effort by design: report the problem and let pandas try.
            print(f"[features] pyarrow could not read schema of {pq_path} ({exc}); falling back to pandas.")
    return list(pd.read_parquet(pq_path).head(0).columns)


mode = _mode_norm(CLASSIFY_BY_VACUOLE_OR_CELL)
plate_set = [str(p) for p in PLATES_TO_CLASSIFY]
well_set = list(WELLS_TO_CLASSIFY)

# Identifier columns are never offered as features.
exclude_exact = {"plate", "well", "tile", "vacuole_id", "label"}
# pandas stores unnamed index levels as "__index_level_<n>__" columns.
_index_level_re = re.compile(r"__index_level_\d+__")

# `ordered_features` keeps first-seen order; `seen` makes the membership test O(1).
ordered_features, seen = [], set()
_n_parquets_read = 0

for plate in sorted(plate_set):
    for well in sorted(well_set):
        pq_path = _pq_path_for(plate, well, PQ_DIR, mode)
        if not pq_path.exists():
            continue
        _n_parquets_read += 1

        for col in _parquet_column_names(pq_path):
            lc = str(col).lower()
            if (
                lc in exclude_exact
                or ("bounds" in lc)
                or ("location" in lc)
                or _index_level_re.fullmatch(lc)
            ):
                continue
            if col not in seen:
                seen.add(col)
                ordered_features.append(col)

# Without this notice an empty listing is indistinguishable from "no features".
if _n_parquets_read == 0:
    print(f"[features] No {mode} parquet files found in {PQ_DIR} for the selected plates/wells.")

for name in ordered_features:
    print(name)

### 1b. <font color='red'>SET PARAMETERS</font>: Threshold Settings

**Feature thresholding (optional):**
- `THRESHOLD_FEATURE`: Feature to threshold by (e.g., `"nucleus_DAPI_mean"`). Set to `None` to skip thresholding.
- `THRESHOLD_MIN`: Minimum value (exclusive). Masks below this will not be displayed.
- `THRESHOLD_MAX`: Maximum value (inclusive). Masks above this will not be displayed.
- `THRESHOLD_MIN_PERCENTILE`: Percentile minimum (0-1). Masks below this percentile will not be displayed.
- `THRESHOLD_MAX_PERCENTILE`: Percentile maximum (0-1). Masks above this percentile will not be displayed.

**Batch settings:**
- `BATCH_SIZE`: Number of images to display per round of classification. Default is 10.
- `OUT_OF_THRESHOLD_RANDOMIZER`: Number of out-of-threshold images to include per batch (0 to `BATCH_SIZE`). Default is 0.

In [None]:
# Feature used to pre-filter which masks are shown; None disables thresholding.
THRESHOLD_FEATURE = "vacuole_area_x"
# Absolute bounds on THRESHOLD_FEATURE: min is exclusive, max is inclusive; None = no bound.
THRESHOLD_MIN = None
THRESHOLD_MAX = 11600
# Percentile bounds (0-1) on THRESHOLD_FEATURE; None = no bound.
# NOTE(review): an absolute max and a percentile min are both set here; how
# prepare_mask_dataframes combines the two kinds of bound is not visible in this notebook.
THRESHOLD_MIN_PERCENTILE = 0.5
THRESHOLD_MAX_PERCENTILE = None
# Number of images shown per labeling round.
BATCH_SIZE = 10
# Number of out-of-threshold images mixed into each batch (0 to BATCH_SIZE).
OUT_OF_THRESHOLD_RANDOMIZER = 1

In [None]:
# Columns that together identify one mask instance.
_KEYS = ["plate", "well", "tile", "mask_label"]

# Build the labeling inputs from the phenotype output directory using the
# plate/well and threshold settings above. The summary, in-threshold and
# out-of-threshold frames are used by the cells below.
# NOTE(review): thr_dbg is not referenced again in this notebook — presumably
# threshold debugging info; confirm against prepare_mask_dataframes.
mask_summary_df, mask_instances_df, mask_instances_out_of_threshold_df, thr_dbg = prepare_mask_dataframes(
    mode=CLASSIFY_BY_VACUOLE_OR_CELL,
    pq_root=PHENOTYPE_OUTPUT_FP,
    plates=PLATES_TO_CLASSIFY,
    wells=WELLS_TO_CLASSIFY,
    keys=_KEYS,
    threshold_feature=THRESHOLD_FEATURE,
    threshold_min=THRESHOLD_MIN,
    threshold_max=THRESHOLD_MAX,
    threshold_min_percentile=THRESHOLD_MIN_PERCENTILE,
    threshold_max_percentile=THRESHOLD_MAX_PERCENTILE,
    verbose=True
)

In [None]:
# Key columns identifying one mask instance in every frame handled below.
_key_cols = ["plate", "well", "tile", "mask_label"]


def _without_existing(instances_df, existing_keys):
    """Return the rows of ``instances_df`` whose key is not in ``existing_keys``.

    Implemented as a left merge against a marker frame followed by keeping the
    unmatched rows. An empty input frame is returned unchanged.
    """
    if instances_df.empty:
        return instances_df
    marker_df = pd.DataFrame(list(existing_keys), columns=_key_cols).assign(_ex=1)
    merged_df = instances_df.merge(marker_df, on=_key_cols, how="left")
    return merged_df[merged_df["_ex"].isna()].drop(columns="_ex").reset_index(drop=True)


if not ADD_TRAINING_DATA:
    _EXISTING_KEYS = set()
else:
    # Load the previously saved labels and flag them as pre-existing.
    df_existing = pd.read_parquet(EXISTING_TRAINING_PATH)
    seeded = _normalize_keys(df_existing, _mode_cfg, CLASS_TITLE)
    seeded["_existing"] = True

    if "_STATE" not in globals():
        # No labeling session yet: leave the seed in plain globals for the UI
        # cell to pick up when it builds the session state.
        manual_classified_df = seeded.copy()
        manual_unclassified_df = pd.DataFrame(columns=_key_cols)
    else:
        # A labeling session already exists: fold the seed into it. On key
        # collisions the seeded row wins because it is concatenated last.
        _prior_classified = _STATE.get("manual_classified_df")
        if _prior_classified is None:
            _STATE["manual_classified_df"] = seeded.copy()
        else:
            _STATE["manual_classified_df"] = pd.concat(
                [_prior_classified, seeded], ignore_index=True
            ).drop_duplicates(subset=_key_cols, keep="last")
        if _STATE.get("manual_unclassified_df") is None:
            _STATE["manual_unclassified_df"] = pd.DataFrame(columns=_key_cols)

    # NOTE(review): keys are cast to (int, str, int, int); the merges below
    # assume the mask instance frames use matching dtypes — confirm.
    _EXISTING_KEYS = {
        (int(r.plate), str(r.well), int(r.tile), int(r.mask_label))
        for r in seeded.itertuples(index=False)
    }

# When extending a dataset without relabeling, hide masks that already have a label.
if ADD_TRAINING_DATA and not RELABEL_CLASSIFICATIONS:
    mask_instances_df = _without_existing(mask_instances_df, _EXISTING_KEYS)
    mask_instances_out_of_threshold_df = _without_existing(
        mask_instances_out_of_threshold_df, _EXISTING_KEYS
    )

print("[training] Existing keys in session:", len(_EXISTING_KEYS))

### 1c. <font color='red'>SET PARAMETERS</font>: Display Settings

**Channel visualization:**
- `DISPLAY_CHANNEL`: Channels to display for manual classification.
- `CHANNEL_COLORS`: Colors for each channel. Must align with `DISPLAY_CHANNEL` order. See [matplotlib colors](https://matplotlib.org/stable/gallery/color/named_colors.html).

**Selection method:**
- `TRAINING_DATASET_SELECTION`: Choose `"random"` or `"top_n"`.
  - `"random"`: Randomly select masks from specified plates and wells.
  - `"top_n"`: Select masks from tiles with the most instances. Useful for checking alignment with original images.
- `TOP_N`: If using `"top_n"`, specify which ranked tile to use.

**Other settings:**
- `SCALE_BAR`: Scale bar length in pixels. If value exceeds image size, displays as dashed lines.
- `RANDOM_SEED`: Random seed for reproducibility.

In [None]:
# Channels shown in the labeling UI; each must appear in the config's phenotype channel_names.
DISPLAY_CHANNEL = ["CDPK1", "DAPI"]
# One matplotlib color per entry of DISPLAY_CHANNEL, in the same order.
CHANNEL_COLORS = ["r", "c"]
# Mask selection strategy: "random" or "top_n".
TRAINING_DATASET_SELECTION = "random"
# Ranked tile to use when TRAINING_DATASET_SELECTION is "top_n".
# NOTE(review): TOP_N is not passed to _render_next_batch in the cell below — confirm how it is consumed.
TOP_N = 0
# Scale bar length in pixels.
SCALE_BAR = 30
# Seed for the session RNG; only applied the first time the session state is created.
RANDOM_SEED = 42

In [None]:
_KEYS = ["plate", "well", "tile", "mask_label"]

# Keep the session dictionary across re-runs of this cell; create it only once.
if "_STATE" not in globals():
    _STATE = {}

# Fill in any entry the session does not have yet. Existing entries (including
# an already-advanced RNG) are left untouched.
_state_defaults = {
    "rng": np.random.default_rng(RANDOM_SEED),
    "aligned_cache": {},
    "mask_cache": {},
    "parquet_cache": {},
    "mode": str(CLASSIFY_BY_VACUOLE_OR_CELL).lower(),
    "container": None,
    "rows_state": [],
    "button": None,
    "status": None,
    "channel_header": None,
    "tile_order_df": None,
    "tile_idx": 0,
}
for _state_key, _state_default in _state_defaults.items():
    _STATE.setdefault(_state_key, _state_default)

# If the session has no classified frame yet, adopt a non-empty global
# `manual_classified_df` (set by the seeding cell) as the starting point.
if _STATE.get("manual_classified_df") is None:
    seed_df = globals().get("manual_classified_df", None)
    _use_seed = isinstance(seed_df, pd.DataFrame) and not seed_df.empty
    _STATE["manual_classified_df"] = seed_df.copy() if _use_seed else None

# Same adoption rule for the unclassified frame.
if _STATE.get("manual_unclassified_df") is None:
    seed_unc = globals().get("manual_unclassified_df", None)
    _use_seed = isinstance(seed_unc, pd.DataFrame) and not seed_unc.empty
    _STATE["manual_unclassified_df"] = seed_unc.copy() if _use_seed else None

# Pass the classified frame through the schema helper (it also receives None
# when nothing was seeded).
_STATE["manual_classified_df"] = _ensure_mc_schema(_STATE["manual_classified_df"], CLASS_TITLE, _KEYS)

# A missing or empty unclassified frame is replaced by an empty frame with the key columns.
_unclassified = _STATE["manual_unclassified_df"]
if _unclassified is None or _unclassified.empty:
    _STATE["manual_unclassified_df"] = pd.DataFrame(columns=_KEYS)

# Validate the requested display channels against the configured channel list.
CHANNEL_NAMES = config["phenotype"]["channel_names"]
if len(DISPLAY_CHANNEL) != len(set(DISPLAY_CHANNEL)):
    raise ValueError("DISPLAY_CHANNEL contains repeated channels. Each must be unique.")
missing = [name for name in DISPLAY_CHANNEL if name not in CHANNEL_NAMES]
if missing:
    raise ValueError(f"DISPLAY_CHANNEL not found in channel_names: {missing}")
# Position of each display channel within CHANNEL_NAMES.
CHANNEL_INDICES = list(map(CHANNEL_NAMES.index, DISPLAY_CHANNEL))

# Turn each configured color into an ("rgb", (r, g, b)) entry. A channel with no
# color (CHANNEL_COLORS too short or not a list) falls back to a gray entry.
resolved_colors = []
for idx, ch in enumerate(DISPLAY_CHANNEL):
    _has_color = isinstance(CHANNEL_COLORS, list) and idx < len(CHANNEL_COLORS)
    color_name = CHANNEL_COLORS[idx] if _has_color else None
    if color_name is None:
        resolved_colors.append(("gray", (1.0, 1.0, 1.0)))
        continue
    try:
        rgb = mcolors.to_rgb(color_name)
    except ValueError:
        raise ValueError(f"Invalid color '{color_name}' for channel '{ch}'. Use a valid matplotlib color name or hex.")
    resolved_colors.append(("rgb", rgb))

# Run the classified frame through the schema helper again; a missing frame is
# passed as None, exactly as an explicit None branch would do.
_STATE["manual_classified_df"] = _ensure_mc_schema(
    _STATE.get("manual_classified_df"), CLASS_TITLE, _KEYS
)

# Guarantee an unclassified frame exists before the UI starts.
if _STATE.get("manual_unclassified_df") is None:
    _STATE["manual_unclassified_df"] = pd.DataFrame(columns=_KEYS)

# Expose the session frames under the global names later cells use.
manual_classified_df = _STATE["manual_classified_df"]
manual_unclassified_df = _STATE["manual_unclassified_df"]

# Launch the labeling UI for the next batch with the settings chosen above.
_render_next_batch(
    state=_STATE,
    keys=_KEYS,
    # labeling task
    MODE=CLASSIFY_BY_VACUOLE_OR_CELL,
    CLASS_TITLE=CLASS_TITLE,
    CLASSIFICATION=CLASSIFICATION,
    ADD_TRAINING_DATA=ADD_TRAINING_DATA,
    RELABEL_CLASSIFICATIONS=RELABEL_CLASSIFICATIONS,
    EXISTING_KEYS=_EXISTING_KEYS,
    # which masks to offer
    TRAINING_DATASET_SELECTION=TRAINING_DATASET_SELECTION,
    BATCH_SIZE=BATCH_SIZE,
    OUT_OF_THRESHOLD_RANDOMIZER=OUT_OF_THRESHOLD_RANDOMIZER,
    THRESHOLD_FEATURE_PRESENT=(THRESHOLD_FEATURE is not None),
    mask_summary_df=mask_summary_df,
    mask_instances_df=mask_instances_df,
    mask_instances_out_of_threshold_df=mask_instances_out_of_threshold_df,
    # image source and rendering
    PHENOTYPE_OUTPUT_FP=PHENOTYPE_OUTPUT_FP,
    CHANNEL_NAMES=CHANNEL_NAMES,
    DISPLAY_CHANNEL=DISPLAY_CHANNEL,
    CHANNEL_INDICES=CHANNEL_INDICES,
    RESOLVED_COLORS=resolved_colors,
    SCALE_BAR=SCALE_BAR,
)

In [None]:
# Consolidate the manual labels into a training dataset (write=True). The
# returned path is later recorded in the config as the default training dataset,
# and consolidated_df is the default training input in section 2.
# NOTE(review): manual_classified_df was bound to _STATE["manual_classified_df"]
# in the previous cell; confirm that labels added through the UI are visible
# through this reference (i.e. the frame is updated in place, not replaced).
consolidated_df, training_dataset_out_path = consolidate_manual_classifications(
    manual_classified_df=manual_classified_df,
    class_title=CLASS_TITLE,
    classify_mode=CLASSIFY_BY_VACUOLE_OR_CELL,
    phenotype_output_fp=PHENOTYPE_OUTPUT_FP,
    classifier_output_dir=CLASSIFIER_OUTPUT_DIR,
    write=True,
    verbose=True,
)
print(consolidated_df.head())

## 2. Training

Configure training parameters and train multiple model types to find the best classifier.

### 2a. <font color='red'>SET PARAMETERS</font>: Training Dataset Configuration

- `TRAINING_DATASET_FP`: Path to training dataset(s) in list format. Set to `None` to use the last classified dataset.
- `METADATA_COLS`: Columns to treat as metadata (not features).
- `TRAINING_CHANNELS`: Channels to include in training features.
- `TRAINING_NAME`: Name identifier for this training run.

In [None]:
# Training dataset path(s) as a list; None uses the dataset consolidated in section 1.
TRAINING_DATASET_FP = None
# Columns treated as metadata rather than features.
# NOTE(review): 'parasite_count' equals the CLASS_TITLE set in section 1a —
# update this entry if CLASS_TITLE changes.
METADATA_COLS = ['vacuole_id', 'cell_id', 'plate', 'well', 'tile', 'parasite_count']
# Channels whose features are used for training; all other configured channels are excluded.
TRAINING_CHANNELS = ['DAPI', 'CDPK1']
# Identifier for this training run.
TRAINING_NAME = "paravacuole"

In [None]:
# Without an explicit dataset, train on the labels consolidated in section 1.
if TRAINING_DATASET_FP is None:
    print("No training dataset provided, using last classified dataset")
    data = consolidated_df
else:
    data = load_cellprofiler_data(TRAINING_DATASET_FP)

# Split the configured channels into those used for training and those excluded.
channel_names = config["phenotype"]["channel_names"]
feature_markers = dict.fromkeys(
    (c for c in channel_names if c in TRAINING_CHANNELS), True
)
exclude_markers = [c for c in channel_names if c not in TRAINING_CHANNELS]

# Echo the effective training setup for a quick sanity check.
for _summary_line in (
    f"Class names: {CLASSIFICATION}",
    f"Class names to stored numeric values: {class_mapping}",
    f"Features to train upon: {feature_markers}",
    f"Features to exclude: {exclude_markers}",
    f"Target column: {CLASS_TITLE}",
):
    print(_summary_line)

### 2b. <font color='red'>SET PARAMETERS</font>: Filter Training Classes

- `REMOVE_MASK_LABELS`: List of class labels to exclude from training (e.g., `["1 parasite"]`). Set to `None` to keep all classes.

In [None]:
# Class labels to exclude from training; None keeps every class in CLASSIFICATION.
REMOVE_MASK_LABELS = None
# The three returned values are passed straight to train_classifier_pipeline below.
class_labels, filtered_class_mapping, class_id = filter_classes(CLASSIFICATION, class_mapping, REMOVE_MASK_LABELS)

In [None]:
# Candidate models handed to train_classifier_pipeline.
# NOTE(review): each tuple appears to be (config name, model type, scaler,
# feature-selection options or None) — confirm against train_classifier_pipeline.
model_configs = [
    ('rf_standard', 'rf', 'standard', None),
    ('svc_standard', 'svc', 'standard', None),
    ('xgb_standard', 'xgb', 'standard', None),
    ('lgb_standard', 'lgb', 'standard', None),
    ('xgb_robust', 'xgb', 'robust', None),
    ('xgb_minmax', 'xgb', 'minmax', None),
    ('xgb_none', 'xgb', 'none', None),
    ('xgb_none_var', 'xgb', 'none', {'enhance': True, 'remove_low_variance': True, 'remove_correlated': False, 'select_k_best': None}),
    ('xgb_none_corr', 'xgb', 'none', {'enhance': True, 'remove_low_variance': False, 'remove_correlated': True, 'select_k_best': None}),
    ('xgb_none_kbest100', 'xgb', 'none', {'enhance': True, 'remove_low_variance': False, 'remove_correlated': False, 'select_k_best': 100}),
    ('xgb_none_combined', 'xgb', 'none', {'enhance': True, 'remove_low_variance': True, 'remove_correlated': True, 'select_k_best': 100}),
]

In [None]:
# Train every configuration in model_configs on the selected data.
pipeline_result = train_classifier_pipeline(
    data=data,
    class_title=CLASS_TITLE,
    class_id=class_id,
    class_labels=class_labels,
    filtered_class_mapping=filtered_class_mapping,
    metadata_cols=METADATA_COLS,
    feature_markers=feature_markers,
    exclude_markers=exclude_markers,
    training_name=TRAINING_NAME,
    model_configs=model_configs,
    classifier_output_dir=CLASSIFIER_OUTPUT_DIR,
    training_channels=TRAINING_CHANNELS,
    verbose=True,
)
# Metrics table from the pipeline result.
multiclass_df = pipeline_result['metrics_df']
print(multiclass_df.head())

# Run directory of this training run; saved to the config as "last_training_run" in the next cell.
run_dir = pipeline_result["dirs"]["run"]
print(f"All outputs written to: {run_dir}")

In [None]:
def _to_list_of_str(value, fallback):
    """Normalize to a list of strings."""
    base = fallback if value is None else value
    if isinstance(base, (list, tuple, set)):
        return [str(x) for x in base]
    return [str(base)]

# Record the parameters of this training run in the config so that the
# selection section (and later notebook sessions) can read them back.
config["train_classifier"] = {
    "classify_by": str(CLASSIFY_BY_VACUOLE_OR_CELL),
    "class_title": CLASS_TITLE,
    "classification": list(CLASSIFICATION),
    "class_mapping": class_mapping,
    "metadata_cols": list(METADATA_COLS),
    "training_channels": list(TRAINING_CHANNELS),
    # Explicit dataset path(s) if given, otherwise the dataset written in section 1.
    "training_datasets": _to_list_of_str(TRAINING_DATASET_FP, training_dataset_out_path),
    "channel_names": list(config.get("phenotype", {}).get("channel_names", [])),
    "last_training_run": str(run_dir),
}

# Serialise BEFORE opening the file: opening with "w" truncates it, so a
# serialisation error after that point would leave an empty or partial config.
# safe_dump with sort_keys=False matches the final cell of this notebook and
# the safe_load used to read the file, and keeps the existing key order.
_config_yaml = yaml.safe_dump(config, sort_keys=False, default_flow_style=False)
with open(CONFIG_FILE_PATH, "w") as config_file:
    config_file.write(_config_yaml)

print("Updated config['train_classifier'] with current notebook parameters.")

## 3. Selection

Run the cells below to evaluate trained models and select the best one with an appropriate confidence threshold.

In [None]:
# Name of the run directory to evaluate; None falls back to
# config["train_classifier"]["last_training_run"].
MODEL_RUN_DIR = "run_20250904_144239"
# Plates and wells whose phenotype parquets are loaded for testing.
TEST_PLATES = ["1", "2"]
TEST_WELLS = ["A1", "A2", "A3"]

### 3a. <font color='red'>SET PARAMETERS</font>: Model Selection Settings

- `MODEL_RUN_DIR`: Name of the model run directory to test (e.g., `"run_20250903_114514"`); it is looked up in the `classifier` subfolder of the classifier output directory. Set to `None` to use the last training run.
- `TEST_PLATES`: Plates to use for testing.
- `TEST_WELLS`: Wells to use for testing.

In [None]:
plate = TEST_PLATES
well = TEST_WELLS

# Read the training settings back from the config (as saved by the training section).
_train_cfg = config["train_classifier"]
classify_by = _train_cfg["classify_by"]
class_title = _train_cfg["class_title"]
class_mapping = _train_cfg["class_mapping"]

# Evaluate the named run if one was given, otherwise the most recent training run.
CLASSIFIER_DIR_PATH = (
    Path(config["train_classifier"]["last_training_run"])
    if MODEL_RUN_DIR is None
    else Path(CLASSIFIER_OUTPUT_DIR) / "classifier" / str(MODEL_RUN_DIR)
)

# Accept a single value or a list/tuple for plates and wells; always work with lists of str.
plates = [str(p) for p in (plate if isinstance(plate, (list, tuple)) else [plate])]
wells = [str(w) for w in (well if isinstance(well, (list, tuple)) else [well])]

# The parquet filename suffix depends on whether cells or vacuoles are classified.
parquet_dir = Path(PHENOTYPE_OUTPUT_FP) / "parquets"
ctype = str(classify_by).lower()
_suffix_by_type = {
    "cell": "phenotype_cp.parquet",
    "vacuole": "phenotype_vacuoles.parquet",
}
if ctype not in _suffix_by_type:
    raise ValueError(f"Unsupported classify_by value: {classify_by}. Use 'cell' or 'vacuole'.")
name_suffix = _suffix_by_type[ctype]

master_phenotype_df, meta = build_master_phenotype_df(
    plates=plates,
    wells=wells,
    name_suffix=name_suffix,
    parquet_dir=parquet_dir,
    display_fn=display,
    verbose=True
)

# Persist the metadata column list as a one-column TSV, then split the loaded
# table into metadata and feature parts using the same list.
# Note: this rebinds METADATA_COLS to the value stored in the config.
CONFIG_FOLDER_PATH = Path("config/")
METADATA_COLS_FP = CONFIG_FOLDER_PATH / "cell_data_metadata_cols.tsv"
METADATA_COLS = config["train_classifier"]["metadata_cols"]
pd.Series(METADATA_COLS).to_csv(METADATA_COLS_FP, index=False, header=False, sep="\t")
metadata, features = split_cell_data(master_phenotype_df, METADATA_COLS)
print(metadata.shape, features.shape)

## <font color='red'>Note</font>

The cell below displays evaluation statistics for the selected model run. If no model was selected, the last trained model will be displayed. Based on the accuracy and F1 scores, select a model in the following cell. Your selection will be saved when you run the final cell of this notebook.

In [None]:
# Show the evaluation output for the chosen run directory; the return value is discarded.
_=display_pngs_in_plots_and_list_models(
    CLASSIFIER_DIR_PATH, width=1200
)

### 3b. <font color='red'>SET PARAMETERS</font>: Model and Montage Settings

- `CLASSIFIER_MODEL`: Name of the model to use (e.g., `"xgb_standard"`). Set to `None` to use the best performing model.
- `MONTAGE_CHANNEL`: Channel to use for montage images. Should be one of the channels used for training.
- `COLLAPSE_COLS`: Columns to collapse on when creating classification summaries (e.g., `["plate", "well"]`).

In [None]:
# Model name to use (e.g. "xgb_standard"); None selects the best performing model.
CLASSIFIER_MODEL = None
# Channel used for the montage images.
MONTAGE_CHANNEL = "CDPK1"
# Columns to collapse on when building the classification summary.
COLLAPSE_COLS = ["plate", "well"]

In [None]:
# Resolve the model file path and model name for CLASSIFIER_MODEL within the run
# directory; both are used by the cells below and saved to the config at the end.
CLASSIFIER_PATH, model_name = resolve_classifier_model_dill_path(CLASSIFIER_DIR_PATH, CLASSIFIER_MODEL)
print("Selected model: " + model_name)
# Show the evaluation images of the selected model only.
_ = show_model_evaluation_pngs(CLASSIFIER_DIR_PATH, model_name, width=500)

In [None]:
# Load the selected model and classify the test plates/wells.
classifier = CellClassifier.load(CLASSIFIER_PATH)
classified_metadata, classified_features = classifier.classify_cells(metadata, features)

print(classified_metadata.head())

# Predicted class column, reused for the class list and the counts below.
_predicted = classified_metadata[class_title]
CELL_CLASSES = list(_predicted.unique())

print(class_title + " counts:")
print(_predicted.value_counts())

# Histogram of the per-row confidence stored in "<class_title>_confidence".
print("\n" + class_title + " confidences:")
_confidence_col = class_title + "_confidence"
classified_metadata[_confidence_col].hist()
plt.show()

### 3c. <font color='red'>SET PARAMETERS</font>: Confidence Calibration

- `CONFIDENCE_CORRECTION`: Set to `"post-hoc"` to apply post-hoc confidence correction, or to `None` to leave the confidences uncorrected.
- `CALIBRATION_DATASET_FP`: Path to calibration dataset. Set to `None` to use the training dataset.
- `CALIBRATION_METHOD`: Calibration method to use. Options are `"isotonic"` or `"sigmoid"`.

In [None]:
# None or "post-hoc"; passed to calibrate_confidence as confidence_correction.
CONFIDENCE_CORRECTION = None
# Calibration dataset path; None uses the training dataset(s) recorded in the config.
CALIBRATION_DATASET_FP = None
# "isotonic" or "sigmoid".
CALIBRATION_METHOD = "isotonic"

In [None]:
# Fall back to the training dataset(s) recorded in the config when no
# calibration dataset was given. This rebinds CALIBRATION_DATASET_FP.
if CALIBRATION_DATASET_FP is None:
    print("No calibration dataset provided, using training dataset")
    CALIBRATION_DATASET_FP = config["train_classifier"]["training_datasets"]
    if CALIBRATION_DATASET_FP is None or len(CALIBRATION_DATASET_FP) == 0:
        raise ValueError("No calibration dataset provided and no training dataset found in config.")

# Either way, the manually labeled data is loaded from the resolved path(s).
manual_labeled_data = load_cellprofiler_data(CALIBRATION_DATASET_FP)

# Returns the updated classified metadata plus a `meta` object, which is
# printed below (this rebinds the earlier `meta` variable).
classified_metadata, meta = calibrate_confidence(
    # data
    master_phenotype_df=master_phenotype_df,
    classified_metadata=classified_metadata,
    manual_labeled_data=manual_labeled_data,
    # model and target
    classifier_path=CLASSIFIER_PATH,
    classify_by=classify_by,
    class_title=class_title,
    # calibration options
    confidence_correction=CONFIDENCE_CORRECTION,
    calibration_method=CALIBRATION_METHOD,
    min_samples_isotonic=50,
    # test subset
    test_plate=plates,
    test_well=wells,
    verbose=True,
)

print(meta)

In [None]:
# Build montages using MONTAGE_CHANNEL and a summary collapsed on COLLAPSE_COLS
# from the classified metadata; the figure is shown (show_figure=True).
fig, axes, montages, titles, ORDERED_CLASSES, summary_df = build_montages_and_summary(
    master_phenotype_df=master_phenotype_df,
    classified_metadata=classified_metadata,
    classify_by=classify_by,
    class_mapping=class_mapping,
    class_title=class_title,
    root_fp=ROOT_FP,
    channels=config["phenotype"]["channel_names"],
    montage_channel=MONTAGE_CHANNEL,
    collapse_cols=COLLAPSE_COLS,
    verbose=True,
    show_figure=True,
    display_fn=display,  # optional
)

In [None]:
# Display settings for the rank-line UI below. DISPLAY_CHANNEL, CHANNEL_COLORS
# and SCALE_BAR rebind the variables of the same name from section 1c.
DISPLAY_CHANNEL = ["CDPK1", "DAPI"]
CHANNEL_COLORS = ["r", "c"]
SCALE_BAR = 30
# Passed to launch_rankline_ui as minimum_difference.
MINIMUM_DIFFERENCE = 0.001

In [None]:
# Launch the rank-line UI on the classified test data (auto_display=True).
# The returned object is kept in `w`.
w = launch_rankline_ui(
    classified_metadata=classified_metadata,
    class_title=class_title,
    classify_by=classify_by,
    class_mapping=class_mapping,
    phenotype_output_fp=PHENOTYPE_OUTPUT_FP,
    channel_names=list(config["phenotype"]["channel_names"]),
    display_channels=DISPLAY_CHANNEL,
    channel_colors=CHANNEL_COLORS,
    test_plate=plates,
    test_well=wells,
    scale_bar_px=SCALE_BAR,
    minimum_difference=MINIMUM_DIFFERENCE,
    thumbnail_px=150,
    auto_display=True,
)

## Add classifier parameters to config file

Running the cells below will save your selected model and confidence threshold to the config file for use in the aggregate pipeline.

In [None]:
# Confidence threshold saved to the config in the final cell.
CONFIDENCE_THRESHOLD = 0.59
# Reminder of which model the saved settings will refer to.
print("Currently applied model: " + model_name)

In [None]:
# Record the model that was actually applied. When CLASSIFIER_MODEL is None the
# model was auto-selected, so store its resolved name rather than the literal
# string "None" that str(None) would write into the config.
_selected_model = str(model_name) if CLASSIFIER_MODEL is None else str(CLASSIFIER_MODEL)

config["train_classifier"].update({
    "CLASSIFIER_MODEL": _selected_model,
    "CLASSIFIER_PATH": str(CLASSIFIER_PATH),
    "CONFIDENCE_THRESHOLD": CONFIDENCE_THRESHOLD,
})

# Serialise BEFORE opening the file: opening with "w" truncates it, so a
# serialisation error after that point would leave an empty or partial config.
_config_yaml = yaml.safe_dump(config, sort_keys=False, default_flow_style=False)
with open(CONFIG_FILE_PATH, "w") as config_file:
    config_file.write(_config_yaml)