# Classify WT and Null Genotypes with XGBoost
Plates 3, 3p, and 5 are used in all splits to classify genotypes as either WT or Null.
The feature-selected data is used in all data splits.
Pre-evaluation metrics are stored for all splits from these plates.

In [1]:
import pathlib
import random
from collections import defaultdict

import numpy as np
import pandas as pd
from joblib import dump
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

from datetime import datetime
import time


In [2]:
import logging
import pathlib
from datetime import datetime

# Identifiers passed to the logger set up for this notebook
MODEL_ID = "xgboost"
ROLE = "train"

# ============================================
# 1) Choose a RUN_ID
# ============================================
# Default: a month_day_hour_minute stamp taken when the cell runs
RUN_ID = datetime.now().strftime("%m_%d_%H_%M")

# Hard-coded overrides; only the last assignment takes effect
RUN_ID = "12_07_14_30"
RUN_ID = "12_08_07_53"
ANALYSIS_TYPE = "train"


def setup_logger(
    run_id: str,
    model_id: str,
    role: str,
    log_dir: str = "logs",
    analysis_type: str = ANALYSIS_TYPE,
) -> logging.Logger:
    """
    Create a logger that writes to both stdout and a log file.

    - Logger name:  "<analysis_type>_<run_id>_<model_id>_<role>"
    - Log file:     "log_<analysis_type>_<run_id>_<model_id>.log" in `log_dir`
      (shared by all notebooks for the same model & run & analysis_type).

    Parameters
    ----------
    run_id: str
        Identifier of the run; part of the logger name and the log file name.
    model_id: str
        Identifier of the model; part of the logger name and the log file name.
    role: str
        Role of the calling notebook; part of the logger name only.
    log_dir: str
        Directory that receives the log file. It is created (including any
        missing parent directories) when it does not exist.
    analysis_type: str
        Prefix for the logger name and the log file name.

    Returns
    -------
    logging.Logger
        The configured logger. Calling this function again with the same
        arguments returns the same logger without adding duplicate handlers.
    """
    # Local import keeps this cell self-contained
    import sys

    log_path = pathlib.Path(log_dir)
    # parents=True so a nested log_dir such as "out/logs" works on first use
    log_path.mkdir(parents=True, exist_ok=True)

    logger_name = f"{analysis_type}_{run_id}_{model_id}_{role}"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # don't duplicate logs to root logger

    # Avoid adding handlers multiple times if the cell is re-run
    if not logger.handlers:
        # Common formatter for both handlers
        formatter = logging.Formatter(
            fmt="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%S",
        )

        # Stream handler (stdout); StreamHandler() without an argument would
        # write to stderr, so the stream is passed explicitly
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setLevel(logging.INFO)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

        # File handler (one file per analysis_type + run_id + model_id)
        log_file = log_path / f"log_{analysis_type}_{run_id}_{model_id}.log"
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger



# Notebook-wide logger; log_dir and analysis_type keep their defaults
logger = setup_logger(run_id=RUN_ID, model_id=MODEL_ID, role=ROLE)
logger.info("Initialized logger.")

2025-12-06T10:32:46 [train_12_08_07_53_xgboost_train] INFO: Initialized logger.


## Find the root of the git repo on the host system

In [3]:
# Start from the current working directory and walk upward
cwd = pathlib.Path.cwd()

# The first directory (cwd itself, then each ancestor) that holds a ".git" folder
root_dir = next(
    (
        candidate
        for candidate in (cwd, *cwd.parents)
        if (candidate / ".git").is_dir()
    ),
    None,
)

# Stop early when the notebook is not run from inside a Git checkout
if root_dir is None:
    raise FileNotFoundError("No Git root directory found.")

## Define paths

### Input

In [4]:
# OPTIONAL: If the data (within the cell painting directory) is stored in a different location, add location here
repo_dir = pathlib.Path(
    "/Users/marktalbot/Documents/VC Studio Homework Folders/HighRisk/nf1_schwann_cell_painting_data"
)

# Set data level
data_level = "cleaned"

# Single-cell profile directory; the cleaned profiles sit one folder deeper
data_dir = repo_dir / "3.processing_features/data/single_cell_profiles"
if data_level == "cleaned":
    data_dir = data_dir / "cleaned_sc_profiles"

# strict=True makes a missing parquet file fail here, before anything is read
plate3df_path = (data_dir / "Plate_3_sc_feature_selected.parquet").resolve(
    strict=True
)
plate3pdf_path = (data_dir / "Plate_3_prime_sc_feature_selected.parquet").resolve(
    strict=True
)
plate5df_path = (data_dir / "Plate_5_sc_feature_selected.parquet").resolve(
    strict=True
)

# Load the three plates in the same order as their paths
plate3df, plate3pdf, plate5df = (
    pd.read_parquet(parquet_path)
    for parquet_path in (plate3df_path, plate3pdf_path, plate5df_path)
)

logger.info("Number of single-cells total per plate:")
logger.info(f"Plate 3: {len(plate3df)}")
logger.info(f"Plate 3 prime: {len(plate3pdf)}")
logger.info(f"Plate 5: {len(plate5df)}")

# Seeded generator used by shuffle_data
rng = np.random.default_rng(0)


2025-12-06T10:32:46 [train_12_08_07_53_xgboost_train] INFO: Number of single-cells total per plate:
2025-12-06T10:32:46 [train_12_08_07_53_xgboost_train] INFO: Plate 3: 10206
2025-12-06T10:32:46 [train_12_08_07_53_xgboost_train] INFO: Plate 3 prime: 5126
2025-12-06T10:32:46 [train_12_08_07_53_xgboost_train] INFO: Plate 5: 5348


### Outputs

In [5]:
# Output folder for the trained model, label encoder and pre-evaluation table
data_path = pathlib.Path("data_xgboost")
data_path.mkdir(exist_ok=True, parents=True)


## Splitting and Processing
Functions to split and process data

In [6]:
gene_column = "Metadata_genotype"


def down_sample_by_genotype(_df):
    """
    Return an equal number of cells from each genotype.
    The number of cells in a genotype is the minimum number of cells from all genotypes.

    Parameters
    ----------
    _df: Pandas DataFrame
        The data to be downsampled by the gene_column column.

    Returns
    -------
    The DataFrame downsampled by genotype. All columns, including gene_column,
    are kept and the input is not modified. An input without any genotype
    group (for example an empty DataFrame) yields an empty DataFrame with the
    same columns.
    """
    # Iterate the groups directly instead of groupby().apply(): apply() on a
    # frame that includes the grouping column is deprecated in recent pandas
    # and can drop gene_column from the result.
    genotype_groups = [cells for _, cells in _df.groupby(gene_column)]
    if not genotype_groups:
        return _df.iloc[0:0].copy()

    min_gene = min(len(cells) for cells in genotype_groups)

    # Every genotype is sampled with the same fixed seed, as before
    return pd.concat(
        [cells.sample(n=min_gene, random_state=0) for cells in genotype_groups]
    )


def process_plates(_df):
    """
    Drop rows with NaNs from the single cell data and remove HET cells.

    Parameters
    ----------
    _df: Pandas DataFrame
        Uncleaned plate data with NaNs and HET cells to be removed. Contains the column "Metadata_genotype".
        The passed DataFrame is not modified.

    Returns
    -------
    _df: Pandas DataFrame
        Cleaned single cell data by removing NaNs and HET cells.
    """
    # Non-inplace dropna: the caller's DataFrame keeps its rows, the cleaned
    # data is only available through the return value
    without_nans = _df.dropna()
    return without_nans.loc[without_nans[gene_column] != "HET"]


def shuffle_data(_X):
    """
    Permute each column of a feature table on its own, in place.

    Every column is overwritten with a permutation of its own values drawn
    from the module-level generator ``rng``. Nothing is returned.

    Parameters
    ----------
    _X: Pandas DataFrame
        Feature data whose columns are shuffled; modified in place.
    """
    for feature_name in _X.columns:
        current_values = _X[feature_name]
        _X[feature_name] = rng.permutation(current_values)


def store_pre_evaluation_data(_X, _y, _metadata, _datasplit):
    """
    Append predictions, labels and metadata for one data split to ``eval_data``.

    Reads the module-level ``model`` (the currently trained classifier, either
    a CV fold model or the final model), ``probability_class`` and
    ``eval_data``; nothing is returned.

    Parameters
    ----------
    _X: Pandas DataFrame
        Feature DataFrame from a given plate and data split.
    _y: numpy.ndarray
        A numerically encoded label vector ordered according to _X.
    _metadata: Pandas DataFrame
        Metadata (one row per sample) to carry through.
    _datasplit: str
        Name of the datasplit (for example, "train", "val", "test", "shuffled_*").
    """
    n_rows = _X.shape[0]

    # Second column of predict_proba; stored under the probability_class name
    class_one_probabilities = model.predict_proba(_X)[:, 1]
    predicted_labels = model.predict(_X)

    eval_data[f"probability_{probability_class}"].extend(
        class_one_probabilities.tolist()
    )
    eval_data["datasplit"].extend([_datasplit] * n_rows)
    eval_data["predicted_genotype"].extend(predicted_labels.tolist())
    eval_data["true_genotype"].extend(_y.tolist())

    # Carry every metadata column through unchanged
    for meta_col in _metadata.columns:
        column_values = _metadata[meta_col].tolist()
        eval_data[meta_col].extend(column_values)


## Split and process plates

In [7]:
def create_splits(_wells, _plate):
    """
    Split one plate into a rest (train and validation) part and a test part by well.

    Parameters
    ----------
    _wells: List(String)
        The well names whose single cells go into the test set.

    _plate: Pandas Dataframe
        Single cell data from one plate containing a "Metadata_Well" column.

    Returns
    -------
    Tuple of two DataFrames: the cells outside the given wells (rest) followed
    by the cells inside the given wells (test).
    """
    in_test_wells = _plate["Metadata_Well"].isin(_wells)

    rest_cells = _plate[~in_test_wells]
    test_cells = _plate[in_test_wells]
    return rest_cells, test_cells

In [8]:
# For each plate: clean it, hold out the listed wells as the test split, then
# balance genotypes within the rest and test splits separately.

# Plate 3
plate3df = process_plates(plate3df)
p3_wells = ["C11", "E11", "C3", "F3"]
rest3df, test3df = map(down_sample_by_genotype, create_splits(p3_wells, plate3df))

# Plate 3 prime
plate3pdf = process_plates(plate3pdf)
p3p_wells = ["F11", "G11", "C3", "F3"]
rest3pdf, test3pdf = map(
    down_sample_by_genotype, create_splits(p3p_wells, plate3pdf)
)

# Plate 5
plate5df = process_plates(plate5df)
p5_wells = ["C9", "E11", "E3", "G3"]
rest5df, test5df = map(down_sample_by_genotype, create_splits(p5_wells, plate5df))

## Combine plate columns across each data split

In [9]:
# Columns common to all plates, kept in plate 3's column order.
# A list built from a set intersection has an order that can change between
# interpreter runs (string hash randomization), which would change the feature
# order seen by the model from run to run.
plate_cols = [
    col
    for col in plate3df.columns
    if col in plate3pdf.columns and col in plate5df.columns
]

# ignore_index=True already yields a fresh 0..n-1 index for the stacked rows
restdf = pd.concat(
    [rest3df[plate_cols], rest3pdf[plate_cols], rest5df[plate_cols]],
    ignore_index=True,
)

testdf = pd.concat(
    [test3df[plate_cols], test3pdf[plate_cols], test5df[plate_cols]],
    ignore_index=True,
)

## Encode genotypes and extract feature data

In [10]:
# Columns whose name contains "Metadata"; every other column is a feature
meta_cols = testdf.filter(like="Metadata").columns
feat_cols = testdf.columns.drop(meta_cols)

In [11]:
le = LabelEncoder()

# Fit the genotype -> integer mapping on the rest (train + validation) labels only
y = le.fit_transform(restdf["Metadata_genotype"])
X = restdf.drop(columns=meta_cols)

# Encode the test labels with the already fitted mapping. Re-fitting on the
# test labels could silently produce a different mapping, and the encoder that
# gets saved later would then reflect the test split instead of the training data.
y_test = le.transform(testdf["Metadata_genotype"])
X_test = testdf.drop(columns=meta_cols)

# Class for saving probabilities
probability_class = le.inverse_transform([1])[0]

# Train Models

## Specify parameters for training

In [12]:
# Base XGBoost parameters (not searched)
xgb_params = {
    "objective": "binary:logistic",
    "eval_metric": "logloss",
    "n_jobs": -1,
    "random_state": 0,
    "tree_method": "hist",
}

# Discrete search space for XGBoost hyperparameters
param_choices = {
    "n_estimators": [300, 600, 900],
    "max_depth": [3, 4, 5],
    "learning_rate": [0.05, 0.1],
    "subsample": [0.7, 0.9, 1.0],
    "colsample_bytree": [0.7, 0.9, 1.0],
    "min_child_weight": [1, 3, 5],
    "gamma": [0.0, 0.1, 0.3],
}

# Number of random hyperparameter combinations to try
rand_iter = 20  # you can increase to 30 or more if runtime is acceptable

# Number of CV folds
n_splits = 5

# Track best performance
best_acc = -np.inf
best_hp = None
best_eval_data = None

# Dedicated, seeded sampler: the same candidate hyperparameters are drawn on
# every run (matching the fixed seeds used elsewhere in this notebook) without
# touching the global state of the `random` module.
hp_sampler = random.Random(0)

# Generate random hyperparameter samples (iteration index -> sampled settings)
random_params = {
    i: {key: hp_sampler.choice(values) for key, values in param_choices.items()}
    for i in range(rand_iter)
}


## Hyperparameter search

In [13]:
# Hyperparameter search for XGBoost (with timing logs)
# For every sampled hyperparameter setting: run stratified k-fold CV on (X, y),
# average the validation accuracy over the folds, and remember the setting with
# the highest average together with the evaluation data collected for it.
overall_start = time.time()

for idx, rparams in random_params.items():
    iter_start = time.time()
    # Routed through the logger (not print) so the message also reaches the log file
    logger.info(
        f"{datetime.now().isoformat(timespec='seconds')} "
        f"[INFO] Iteration {idx + 1}/{len(random_params)}: trying params = {rparams}"
    )

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

    # Combine fixed XGBoost params with the sampled hyperparameters
    comb_params = xgb_params | rparams

    # Reset eval data for this hyperparameter setting
    # (store_pre_evaluation_data appends to this module-level name)
    eval_data = defaultdict(list)
    acc = 0.0

    # Loop through the folds
    for fold, (train_index, val_index) in enumerate(skf.split(X, y), start=1):
        fold_start = time.time()

        X_train, X_val = X.iloc[train_index], X.iloc[val_index]
        y_train, y_val = y[train_index], y[val_index]

        # Metadata rows of this validation fold, shared by both stores below
        val_metadata = restdf.iloc[val_index][meta_cols]

        # Create a shuffled version of the validation features for the null model
        X_val_shuf = X_val.copy()
        shuffle_data(X_val_shuf)

        # Train XGBoost model on this fold
        # ("model" is the module-level name read by store_pre_evaluation_data)
        model = XGBClassifier(**comb_params)
        model.fit(X_train, y_train)

        # Evaluate on validation fold
        preds = model.predict(X_val)
        acc += accuracy_score(y_val, preds)

        # Store predictions for evaluation (real and shuffled)
        store_pre_evaluation_data(X_val, y_val, val_metadata, "val")
        store_pre_evaluation_data(X_val_shuf, y_val, val_metadata, "shuffled_val")

        fold_time = (time.time() - fold_start) / 60
        logger.info(
            f"{datetime.now().isoformat(timespec='seconds')} "
            f"[INFO]    Fold {fold}/{n_splits} finished in {fold_time:.2f} min"
        )

    # Average accuracy across folds
    acc = acc / n_splits

    # Keep best performing hyperparameters and their eval data
    # (lists are copied because eval_data is rebuilt on the next iteration)
    if acc > best_acc:
        best_acc = acc
        best_hp = rparams
        best_eval_data = {k: v.copy() for k, v in eval_data.items()}

    iter_time = (time.time() - iter_start) / 60
    logger.info(
        f"{datetime.now().isoformat(timespec='seconds')} "
        f"[INFO] Iteration {idx + 1} finished in {iter_time:.2f} min "
        f"(acc={acc:.4f}, best_acc={best_acc:.4f})"
    )

total_time = (time.time() - overall_start) / 60
logger.info(f"{datetime.now().isoformat(timespec='seconds')} [INFO] Total search time: {total_time:.2f} min")
logger.info(f"Best average validation accuracy = {best_acc}")
logger.info(f"Best hyperparameters = {best_hp}")

# Set eval_data to the best hyperparameter results for saving later
eval_data = defaultdict(list)
for k, v in best_eval_data.items():
    eval_data[k].extend(v)


2025-12-06T10:32:47 [INFO] Iteration 1/20: trying params = {'n_estimators': 900, 'max_depth': 4, 'learning_rate': 0.05, 'subsample': 0.7, 'colsample_bytree': 0.9, 'min_child_weight': 1, 'gamma': 0.3}


KeyboardInterrupt: 

## Retrain model

In [None]:
# Final model: fixed XGBoost settings merged with the best sampled
# hyperparameters (the latter win on any shared key), fit on all of X, y
comb_params = {**xgb_params, **best_hp}

model = XGBClassifier(**comb_params)
model.fit(X, y)


## Shuffle train and test data

In [None]:
# Work on copies so the real feature tables stay intact; shuffle_data permutes
# the columns of the copy it is given in place
X_shuf, X_test_shuf = X.copy(), X_test.copy()

shuffle_data(X_shuf)
shuffle_data(X_test_shuf)

# Save models and model data

## Store pre-evaluation split data

In [None]:
# Record the final model's predictions for the real and shuffled train data
store_pre_evaluation_data(
    _X=X, _y=y, _metadata=restdf[meta_cols], _datasplit="train"
)
store_pre_evaluation_data(
    _X=X_shuf, _y=y, _metadata=restdf[meta_cols], _datasplit="shuffled_train"
)

# ... and for the real and shuffled held-out test data
store_pre_evaluation_data(
    _X=X_test, _y=y_test, _metadata=testdf[meta_cols], _datasplit="test"
)
store_pre_evaluation_data(
    _X=X_test_shuf,
    _y=y_test,
    _metadata=testdf[meta_cols],
    _datasplit="shuffled_test",
)

In [None]:
# File names get a "_qc" tag when the cleaned data level was used
suffix = ""
if data_level == "cleaned":
    suffix = "_qc"

# Persist the final model and the label encoder
dump(model, f"{data_path}/trained_nf1_model{suffix}.joblib")
dump(le, f"{data_path}/trained_nf1_model_label_encoder{suffix}.joblib")

# One row per stored prediction, one column per eval_data key
pd.DataFrame(eval_data).to_parquet(
    f"{data_path}/nf1_model_pre_evaluation_results{suffix}.parquet"
)
