# CellSexID: Human Single-Cell Sex Prediction Analysis

This notebook demonstrates sex prediction across multiple human single-cell RNA-seq datasets using machine learning classifiers. We analyze four datasets to test the robustness and generalizability of sex classification:

## Datasets Analyzed
1. **Data1 (ATL Dataset, GSE294224)**: Adult T-cell leukemia-lymphoma samples
2. **Kidney Dataset (GSE151671)**: Donor kidney cells for transplant analysis
3. **Data2 (AML/MLL Dataset, GSE289435)**: Bone marrow mononuclear cells from AML patients
4. **Data3 (Thymic Dataset, GSE262749)**: Medullary thymic epithelial cells

All datasets in this analysis undergo identical preprocessing steps to ensure standardized and comparable results across different studies:

### **Standardized Quality Control & Preprocessing:**

1. **Gene Filtering**: `min_cells = 3` - Remove genes expressed in fewer than 3 cells
2. **Cell Filtering**: `min_genes = 200` - Remove cells expressing fewer than 200 genes
3. **Mitochondrial Filtering**: `pct_counts_mt < 5` - Keep only cells with less than 5% of counts from mitochondrial genes
4. **Normalization**: `target_sum = 1e4` - Library-size normalize to 10,000 counts per cell
5. **Log Transformation**: `log1p` - Apply log(x + 1) transformation
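
To make the thresholds concrete, the five steps can be written out on a small dense matrix. This is a NumPy sketch of what the steps compute, not scanpy's implementation. The function name `qc_and_normalize` and the toy counts are illustrative assumptions, and the demo lowers `min_cells` and `min_genes` so that a 5-cell matrix survives filtering.

```python
import numpy as np

def qc_and_normalize(counts, gene_names, min_cells=3, min_genes=200,
                     max_mito_pct=5.0, target_sum=1e4):
    """Dense illustration of the five QC steps (input is cells x genes)."""
    counts = np.asarray(counts, dtype=float)
    gene_names = np.asarray(gene_names)

    # 1) keep genes detected in at least `min_cells` cells
    keep_genes = (counts > 0).sum(axis=0) >= min_cells
    counts, gene_names = counts[:, keep_genes], gene_names[keep_genes]

    # 2) keep cells with at least `min_genes` detected genes
    keep_cells = (counts > 0).sum(axis=1) >= min_genes
    counts = counts[keep_cells]

    # 3) keep cells whose mitochondrial share of counts is below the cutoff
    is_mt = np.char.startswith(gene_names.astype(str), "MT-")
    pct_mt = 100.0 * counts[:, is_mt].sum(axis=1) / counts.sum(axis=1)
    counts = counts[pct_mt < max_mito_pct]

    # 4) scale each cell to `target_sum` total counts, then 5) log(x + 1)
    scaled = counts / counts.sum(axis=1, keepdims=True) * target_sum
    return np.log1p(scaled), gene_names

# Toy data (hypothetical counts): rows are cells, columns are genes
genes = ["MT-CO1", "XIST", "RPS4Y1", "RPS4X"]
counts = [
    [0, 5, 0, 20],   # passes all filters
    [1, 0, 7, 30],   # passes all filters (about 2.6% mitochondrial)
    [40, 1, 0, 2],   # mostly mitochondrial counts, removed in step 3
    [0, 0, 3, 25],   # passes all filters
    [0, 0, 0, 9],    # only one detected gene, removed in step 2
]
X, kept = qc_and_normalize(counts, genes, min_cells=2, min_genes=2)
print(X.shape, list(kept))
```

Undoing the log with `np.expm1(X)` and summing each row gives back `target_sum`, which is a quick sanity check that normalization ran before the log transform.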


## Methodology Overview
- **Feature Selection**: Sex-informative marker genes, mostly X- and Y-linked (e.g. XIST, RPS4Y1, DDX3Y), plus the autosomal genes IFIT2 and IFIT3 in the selected panel
- **Models**: Logistic Regression, Linear SVM, XGBoost, Random Forest
- **Validation Strategy**: Cross-dataset validation and statistical robustness testing
- **Performance Metrics**: AUROC, AUPRC, accuracy, precision, recall, F1-score
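
As a reminder of what the listed metrics measure, the sketch below computes the threshold-based scores from confusion counts and AUROC from pairwise ranking, using the notebook's label encoding (0 = female, 1 = male). The function names and toy scores are illustrative assumptions; the analysis cells themselves rely on `sklearn.metrics`.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for 0/1 labels (1 is the positive class)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"accuracy": (tp + tn) / len(pairs),
            "precision": precision, "recall": recall, "f1": f1}

def auroc(y_true, scores):
    """AUROC as the chance that a random positive outranks a random negative."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    # Ties between a positive and a negative score count as half a win
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Toy example: 3 female (0) and 5 male (1) cells with predicted P(male)
y_true = [0, 0, 0, 1, 1, 1, 1, 1]
scores = [0.1, 0.3, 0.6, 0.9, 0.8, 0.7, 0.55, 0.4]
y_pred = [int(s >= 0.5) for s in scores]  # hard calls at a 0.5 threshold

metrics = binary_metrics(y_true, y_pred)
area = auroc(y_true, scores)
print(metrics, round(area, 3))
```

Accuracy, precision, recall and F1 depend on the 0.5 threshold, while AUROC (like AUPRC) is computed from the ranking of the scores alone.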

## Key Contributions
- Cross-dataset generalization analysis
- Optimal marker gene panel identification
- Statistical validation with multiple random seeds
- Performance comparison across tissue types and disease states

## Dataset 1: ATL (Adult T-cell Leukemia-Lymphoma) Analysis

### Data Processing and Quality Control
This cell processes three 10x Genomics samples (ATL1, ATL2, ATL3) from adult T-cell leukemia-lymphoma patients. It loads the data, applies the quality-control filters, and prepares the merged matrix for downstream sex prediction modeling.

In [2]:
#!/usr/bin/env python3
"""
ATL scRNA-seq merge + QC • GSE294224
─────────────────────────────────────
- Load three 10X Genomics datasets (ATL1, ATL2, ATL3)
- Prefix barcodes with sample name
- Concatenate matrices
- Assign sex labels based on patient information
- Perform QC filtering (min_genes, mitochondrial content)
- Library-size normalize, log1p transform
- Save sparse .h5ad
"""

import os
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import scipy.io as spio  # provides mmread for Matrix Market (.mtx) files
import sys

# ───────────── USER PATHS ─────────────
DATA_DIR = "data/GSE294224_RAW"
OUTDIR = "data/human_1"
OUTFILE = os.path.join(OUTDIR, "atl_merged_qc.h5ad")

# Create output directory if it doesn't exist
os.makedirs(OUTDIR, exist_ok=True)

# Sample information
samples = ["ATL1", "ATL2", "ATL3"]
# Sex label for each sample
sex_labels = {"ATL1": "female", "ATL2": "male", "ATL3": "male"}

VERBOSE = True  # one-line progress prints

# ───────────── HELPERS ──────────────
def load_10x_data(sample: str, data_dir: str) -> sc.AnnData:
    """Load one 10X Genomics dataset using direct file loading."""
    if VERBOSE: print(f"• Loading {sample}")
    
    # Map each sample name to the numeric part of its GSM accession
    sample_id_mapping = {
        "ATL1": "8900566",
        "ATL2": "8900568",
        "ATL3": "8900570"
    }
    
    sample_id = sample_id_mapping[sample]
    
    # Build the three 10x file paths for this sample
    matrix_file = os.path.join(data_dir, f"GSM{sample_id}_{sample}_matrix.mtx.gz")
    features_file = os.path.join(data_dir, f"GSM{sample_id}_{sample}_features.tsv.gz")
    barcodes_file = os.path.join(data_dir, f"GSM{sample_id}_{sample}_barcodes.tsv.gz")
    
    # Check that all three input files exist before trying to read them
    for file_path in [matrix_file, features_file, barcodes_file]:
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return None
    
    try:
        # Load matrix - using the correct scipy.io module
        X = spio.mmread(matrix_file).T.tocsr()
        
        # Load features (genes)
        features = pd.read_csv(features_file, sep='\t', header=None)
        if features.shape[1] >= 2:
            # The second column typically contains gene symbols
            gene_names = features[1].values
        else:
            # If there's only one column, use it as both ID and name
            gene_names = features[0].values
        
        # Make gene names unique
        gene_names_unique = make_unique_names(gene_names)
        
        # Load barcodes
        barcodes = pd.read_csv(barcodes_file, sep='\t', header=None)[0].values
        
        # Create AnnData object
        adata = sc.AnnData(X=X)
        adata.obs_names = [f"{sample}_{bc}" for bc in barcodes]
        adata.var_names = gene_names_unique
        
        # Add sample and sex information
        adata.obs["sample"] = sample
        adata.obs["sex"] = sex_labels[sample]
        
        return adata
    
    except Exception as e:
        print(f"Error loading {sample}: {e}")
        return None

def make_unique_names(names):
    """Make duplicate names unique by appending numbers."""
    name_counts = {}
    unique_names = []
    
    for name in names:
        if name in name_counts:
            name_counts[name] += 1
            unique_names.append(f"{name}_{name_counts[name]}")
        else:
            name_counts[name] = 0
            unique_names.append(name)
    
    return unique_names

# ────────── 1) LOAD ALL SAMPLES ──────────
adatas = []
for sample in samples:
    adata = load_10x_data(sample, DATA_DIR)
    if adata is not None:
        adatas.append(adata)
        if VERBOSE:
            print(f"• Successfully loaded {sample}: {adata.n_obs} cells × {adata.n_vars} genes")
    else:
        print(f"❌ Failed to load {sample}")

# Check if we have data to work with
if not adatas:
    print("No data was successfully loaded. Exiting.")
    sys.exit(1)

# ────────── 2) CONCATENATE ──────────
if VERBOSE: print("• Concatenating datasets")
if len(adatas) == 1:
    # If only one dataset was loaded, skip concatenation
    adata = adatas[0].copy()
    print("Only one dataset was loaded, skipping concatenation")
else:
    # Concatenate multiple datasets
    adata = sc.concat(
        adatas,
        join="outer",     # Union of genes
        merge="first",
        fill_value=0,     # Fill gaps with 0
    )

# Ensure the data is sparse
if not sp.issparse(adata.X):
    adata.X = sp.csr_matrix(adata.X)

if VERBOSE:
    print(f"• Concatenated: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# ────────── 3) BASIC QC & FILTERING ──────────
if VERBOSE: print("• Performing QC and filtering")

# Filter genes with low expression
sc.pp.filter_genes(adata, min_cells=3)

# Filter cells with few expressed genes
sc.pp.filter_cells(adata, min_genes=200)

# Identify mitochondrial genes (human MT genes start with MT-)
adata.var["mt"] = adata.var_names.str.startswith("MT-")

# Calculate QC metrics
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

# Filter cells with high mitochondrial content
max_mito_percent = 5
adata = adata[adata.obs["pct_counts_mt"] < max_mito_percent, :].copy()

# ────────── 4) NORMALIZATION ──────────
if VERBOSE: print("• Normalizing data")
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)
sc.pp.log1p(adata)

# ────────── 5) WRITE OUTPUT ──────────
if VERBOSE: print(f"• Writing to {OUTFILE}")
adata.write(OUTFILE, compression="gzip")

# ────────── SUMMARY ──────────
print("✅ Finished:")
print(f"   Cells  : {adata.n_obs:,}")
print(f"   Genes  : {adata.n_vars:,}")
print("   Sex    :")
print(adata.obs["sex"].value_counts(dropna=False))
print("   Samples:")
print(adata.obs["sample"].value_counts())

• Loading ATL1
• Successfully loaded ATL1: 5066 cells × 36610 genes
• Loading ATL2
• Successfully loaded ATL2: 10151 cells × 36610 genes
• Loading ATL3
• Successfully loaded ATL3: 5095 cells × 36610 genes
• Concatenating datasets
• Concatenated: 20,312 cells × 36,610 genes
• Performing QC and filtering
• Normalizing data
• Writing to data/human_1/atl_merged_qc.h5ad
✅ Finished:
   Cells  : 9,292
   Genes  : 22,630
   Sex    :
sex
male      6481
female    2811
Name: count, dtype: int64
   Samples:
sample
ATL2    3982
ATL1    2811
ATL3    2499
Name: count, dtype: int64
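
The counts printed above allow a quick check of how much each sample lost to filtering. The short calculation below copies the per-sample numbers from the loading log and the final summary; nothing else is assumed.

```python
# Cell counts copied from the log above (loaded) and the final summary (after QC)
loaded = {"ATL1": 5066, "ATL2": 10151, "ATL3": 5095}
after_qc = {"ATL1": 2811, "ATL2": 3982, "ATL3": 2499}

# Percentage of loaded cells that survive all QC filters, per sample and overall
retention = {s: 100 * after_qc[s] / loaded[s] for s in loaded}
overall = 100 * sum(after_qc.values()) / sum(loaded.values())

for sample, pct in retention.items():
    print(f"{sample}: {after_qc[sample]:,} / {loaded[sample]:,} cells kept ({pct:.1f}%)")
print(f"All samples: {overall:.1f}% of cells kept")
```

Fewer than half of the loaded cells pass QC overall (about 45.7%), and ATL2 loses the largest share. The log does not break the loss down by filter, so it cannot tell whether the gene-count or the mitochondrial threshold removes more cells.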


## Cross-Dataset Validation: Data2 (MLL) → Data1 (ATL)
- **Training Data**: Data2 (MLL/AML dataset, GSE289435), 1/15 stratified random subsample
- **Testing Data**: Data1 (ATL dataset, GSE294224)

Sex classification models are trained on bone marrow mononuclear cells (Data2) and tested on adult T-cell leukemia-lymphoma samples (Data1). This cross-dataset validation tests generalizability from the AML to the ATL disease context.
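
The training subsample is drawn so that the female/male ratio of the full dataset is preserved. The notebook does this with scikit-learn's `train_test_split(..., stratify=...)`; the sketch below shows the same idea in plain NumPy. The helper name `stratified_subsample` and the toy labels are assumptions for illustration, and the random draw will not reproduce the notebook's exact subset.

```python
import numpy as np

def stratified_subsample(labels, fraction, seed=42):
    """Return sorted indices of a random subsample that keeps class proportions."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    picked = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)           # all cells of this class
        n_keep = max(1, round(len(idx) * fraction))   # at least one cell per class
        picked.append(rng.choice(idx, size=n_keep, replace=False))
    return np.sort(np.concatenate(picked))

# Toy labels: 600 female (0) and 300 male (1) cells
labels = np.array([0] * 600 + [1] * 300)
subset = stratified_subsample(labels, fraction=1 / 15)

print(len(subset), np.bincount(labels[subset]))
```

Sampling within each class separately keeps the 2:1 ratio of the toy labels in the subsample, which a plain random draw would only match approximately.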

In [6]:
#!/usr/bin/env python3
"""
Sex classification using MLL data (human_2) subsample as training and ATL data (human_1) as testing
Using 9 selected marker genes: RPS4Y1, EIF1AY, XIST, DDX3Y, UTY, KDM5D, IFIT3, IFIT2, RPS4X
Models: LogisticRegression, Linear-SVC, XGBoost, Random-Forest
"""

# ─── imports ───────────────────────────────────────────────
import os, pathlib, warnings
import numpy as np, pandas as pd, scanpy as sc, scipy.sparse as sp
import matplotlib.pyplot as plt

from sklearn.impute            import SimpleImputer
from sklearn.preprocessing     import StandardScaler
from sklearn.pipeline          import Pipeline
from sklearn.linear_model      import LogisticRegression
from sklearn.svm               import SVC
from sklearn.ensemble          import RandomForestClassifier
from sklearn.model_selection   import train_test_split
from xgboost                   import XGBClassifier
from sklearn.metrics           import (
    accuracy_score, f1_score, roc_auc_score, average_precision_score,
    confusion_matrix, roc_curve, precision_recall_curve
)

# ─── selected marker panel ──────────────────────────────────────────
SELECTED_MARKERS =  [
    "RPS4Y1", 
    "EIF1AY", 
    "XIST", 
    "DDX3Y", 
    "UTY", 
    "KDM5D", 
    "IFIT3", 
    "IFIT2", 
    "RPS4X"
]

 
# ─── alias dictionary for gene name mapping ───────────────────
alias_to_official = {
    "XIST":"Xist", "RPS27RT":"Rps27rt", "DDX3Y":"Ddx3y", "RPL35":"Rpl35",
    "EIF2S3Y":"Eif2s3y", "EIF2S3L":"Eif2s3y", "GM42418":"Gm42418", "UBA52":"Uba52",
    "RPL36A-PS1":"Rpl36a-ps1", "KDM5D":"Kdm5d", "JARID1D":"Kdm5d", "WDR89":"Wdr89",
    "UTY":"Uty", "LARS2":"Lars2", "AY036118":"AY036118", "RPL9-PS6":"Rpl9-ps6", "RPS27":"Rps27",
    "RPS4Y1":"RPS4Y1", "EIF1AY":"EIF1AY", "IFIT3":"IFIT3", "IFIT2":"IFIT2", 
    "RPS4X":"RPS4X", "RPL29":"RPL29", 
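    # NOTE: the mouse-style entries above (e.g. "XIST" -> "Xist") do rename human symbols.
    # Both datasets get the same renaming and marker lookup is case-insensitive,
    # so train/test matching is unaffected.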
}

# ─── file paths ────────────────────────────────────────────
DATA_DIR = "data"
MLL_H5AD = os.path.join(DATA_DIR, "human_2/mll_merged_qc.h5ad")
ATL_H5AD = os.path.join(DATA_DIR, "human_1/atl_merged_qc.h5ad")  # test data
OUT_DIR = pathlib.Path("result/human2_human1_selected")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ─── helper functions ────────────────────────────────────────
def unify_gene_symbols(adata):
    """Normalize gene symbols using alias dictionary"""
    if not isinstance(adata.var_names, pd.Index):
        return adata
    
    # Create a mapping dictionary for renaming
    rename_dict = {}
    for gene in adata.var_names:
        # Check for aliases (case insensitive)
        gene_upper = gene.upper()
        if gene_upper in alias_to_official:
            rename_dict[gene] = alias_to_official[gene_upper]
    
    # Rename genes if aliases are found
    if rename_dict:
        print(f"Renaming {len(rename_dict)} genes using alias dictionary")
        adata.var_names = [rename_dict.get(g, g) for g in adata.var_names]
        
    # Make variable names unique if needed
    if not adata.var_names.is_unique:
        print("Making gene names unique")
        adata.var_names_make_unique()
        
    return adata

def extract_sex_labels(adata):
    """Extract standardized sex labels (0=female, 1=male)"""
    if "sex" not in adata.obs:
        raise ValueError("'sex' column not found in AnnData.")

    sex = (
        adata.obs["sex"]
          .astype(str).str.strip().str.lower()
          .map({"female": 0, "male": 1})
    )
    mask = sex.notna()
    return adata[mask].copy(), sex[mask].astype(int).values

def make_pipe(clf):
    """Create a preprocessing pipeline for a classifier"""
    steps = [("imp", SimpleImputer(strategy="median"))]
    if isinstance(clf, (LogisticRegression, SVC)):
        steps.append(("sc", StandardScaler(with_mean=False)))
    steps.append(("clf", clf))
    return Pipeline(steps)

def extract_marker_matrix(adata, markers):
    """Extract marker gene expression matrix from AnnData"""
    # Convert var_names to lowercase for case-insensitive matching
    var_lower = {g.lower(): g for g in adata.var_names}
    
    # Find markers present in dataset (case-insensitive)
    present = [var_lower[g.lower()] for g in markers if g.lower() in var_lower]
    
    if len(present) < 2:
        raise ValueError(f"Fewer than 2 marker genes present in dataset. Found: {present}")
    
    # Extract expression matrix as DataFrame
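    # Densify only the marker columns; .toarray() replaces the deprecated .A attribute
    X_sub = adata[:, present].X
    X_df = pd.DataFrame(
        X_sub.toarray() if sp.issparse(X_sub) else np.asarray(X_sub),
        index=adata.obs_names,
        columns=present,
    )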
    
    # Drop constant columns that don't provide information
    nonconst = (X_df != X_df.iloc[0]).any()
    if (~nonconst).any():
        dropped = X_df.columns[~nonconst].tolist()
        warnings.warn(f"Dropping constant marker(s): {dropped}")
        X_df = X_df.loc[:, nonconst]
        present = X_df.columns.tolist()
    
    if len(present) < 2:
        raise ValueError("Need ≥2 informative markers after filtering.")
    
    print(f"Markers used ({len(present)}): {present}")
    
    return X_df, present

# ═════════ 1) Load datasets ══════════════════════════
print("Loading MLL dataset (training data)...")
mll_adata = sc.read_h5ad(MLL_H5AD)
mll_adata = unify_gene_symbols(mll_adata)
mll_adata, mll_y = extract_sex_labels(mll_adata)
print(f"MLL dataset: {mll_adata.n_obs:,} cells  "
      f"(♀ {(mll_y==0).sum()}  ♂ {(mll_y==1).sum()})")

print("\nLoading ATL dataset (test data)...")
test_adata = sc.read_h5ad(ATL_H5AD)
test_adata = unify_gene_symbols(test_adata)
test_adata, y_test = extract_sex_labels(test_adata)
print(f"ATL dataset: {test_adata.n_obs:,} cells  "
      f"(♀ {(y_test==0).sum()}  ♂ {(y_test==1).sum()})")

# ═════════ 2) Subsample MLL dataset (1/15) ══════════════════
print("\nExtracting 1/15 random subsample from MLL dataset...")
subsample_size = len(mll_adata) // 15

indices = np.arange(len(mll_adata))
_, subsample_indices, _, y_subsample = train_test_split(
    indices, mll_y, test_size=subsample_size/len(mll_adata), 
    stratify=mll_y, random_state=42
)

# Create subsampled AnnData object for training
train_adata = mll_adata[subsample_indices].copy()
y_train = mll_y[subsample_indices]
print(f"Training subsample: {train_adata.n_obs:,} cells  "
      f"(♀ {(y_train==0).sum()}  ♂ {(y_train==1).sum()})")
print(f"Subsampling ratio: {train_adata.n_obs/mll_adata.n_obs:.1%} of original MLL data")

# ═════════ 3) Extract marker matrices ══════════════════
print("\nExtracting selected marker genes from training data...")
X_train_df, train_markers = extract_marker_matrix(train_adata, SELECTED_MARKERS)

print("\nExtracting selected marker genes from test data...")
X_test_df, test_markers = extract_marker_matrix(test_adata, SELECTED_MARKERS)

# Find common markers between train and test sets
common_markers = sorted(set(train_markers) & set(test_markers))
if len(common_markers) < 2:
    raise ValueError(f"Fewer than 2 common marker genes between datasets. Found: {common_markers}")

print(f"\nCommon markers used for training and testing ({len(common_markers)}): {common_markers}")

# Use only common markers
X_train = X_train_df[common_markers].values
X_test = X_test_df[common_markers].values

# ═════════ 4) Define models ══════════════════════════
pipelines = {
    "LogisticRegression": make_pipe(LogisticRegression(max_iter=1000, random_state=42)),
    "LinearSVC": make_pipe(SVC(kernel="linear", probability=True, random_state=42)),
    "XGBoost": make_pipe(XGBClassifier(
        eval_metric="logloss", random_state=42,
        n_estimators=100, learning_rate=0.05, max_depth=10)),
    "RandomForest": make_pipe(RandomForestClassifier(max_depth=10, random_state=42)),
}

# Set up for curve data collection and plotting
curve_data_roc = []
curve_data_pr = []
colors = {
    "LogisticRegression": "blue",
    "LinearSVC": "red",
    "XGBoost": "green",
    "RandomForest": "purple"
}

# Create figures for plotting
fig_roc, ax_roc = plt.subplots(figsize=(10, 8))
fig_pr, ax_pr = plt.subplots(figsize=(10, 8))

# ═════════ 5) Train and evaluate models ══════════════════
print("\n" + "="*50)
print("Training and evaluating models using selected genes")
print("="*50)

results = []
for name, model in pipelines.items():
    print(f"\n=== {name} ===")
    model.fit(X_train, y_train)

    # 1) Train performance
    p_tr = model.predict(X_train)
    prob_tr = model.predict_proba(X_train)[:, 1]
    tr_acc = accuracy_score(y_train, p_tr)
    tr_f1 = f1_score(y_train, p_tr)
    tr_roc = roc_auc_score(y_train, prob_tr)
    tr_pr = average_precision_score(y_train, prob_tr)
    print(f" TRAIN → Acc={tr_acc:.4f}, F1={tr_f1:.4f}, AUROC={tr_roc:.4f}, AUPRC={tr_pr:.4f}")

    # 2) Test performance
    p_test = model.predict(X_test)
    prob_test = model.predict_proba(X_test)[:, 1]
    test_acc = accuracy_score(y_test, p_test)
    test_f1 = f1_score(y_test, p_test)
    test_roc = roc_auc_score(y_test, prob_test)
    test_pr = average_precision_score(y_test, prob_test)
    print(f" TEST → Acc={test_acc:.4f}, F1={test_f1:.4f}, AUROC={test_roc:.4f}, AUPRC={test_pr:.4f}")
    print("  Confusion Matrix:")
    print(confusion_matrix(y_test, p_test))
    
    results.append({
        "Model": name,
        "Train_Acc": tr_acc, "Train_F1": tr_f1,
        "Train_AUROC": tr_roc, "Train_AUPRC": tr_pr,
        "Test_Acc": test_acc, "Test_F1": test_f1,
        "Test_AUROC": test_roc, "Test_AUPRC": test_pr,
    })
    
    # Calculate ROC curve points
    fpr, tpr, _ = roc_curve(y_test, prob_test)
    roc_df = pd.DataFrame({"model": name, "fpr": fpr, "tpr": tpr})
    curve_data_roc.append(roc_df)
    
    # Calculate PR curve points
    precision, recall, _ = precision_recall_curve(y_test, prob_test)
    pr_df = pd.DataFrame({"model": name, "precision": precision, "recall": recall})
    curve_data_pr.append(pr_df)
    
    # Plot ROC curve
    ax_roc.plot(fpr, tpr, lw=2, color=colors[name], 
             label=f'{name} (area = {test_roc:.3f})')
    
    # Plot PR curve
    ax_pr.plot(recall, precision, lw=2, color=colors[name], 
            label=f'{name} (area = {test_pr:.3f})')

# ═════════ 6) Save results ══════════════════════════
# Combine and save curve data
all_roc_data = pd.concat(curve_data_roc, ignore_index=True)
all_pr_data = pd.concat(curve_data_pr, ignore_index=True)

all_roc_data.to_csv(OUT_DIR / "human2_to_human1_selected_auroc.csv", index=False)
all_pr_data.to_csv(OUT_DIR / "human2_to_human1_selected_auprc.csv", index=False)

# Finalize and save ROC plot
ax_roc.plot([0, 1], [0, 1], 'k--', lw=2)
ax_roc.set_xlim([0.0, 1.0])
ax_roc.set_ylim([0.0, 1.05])
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('Human2 (MLL) → Human1 (ATL): ROC Curves (Selected Genes)')
ax_roc.legend(loc="lower right")
ax_roc.grid(True, linestyle='--', alpha=0.7)
fig_roc.tight_layout()
fig_roc.savefig(OUT_DIR / "human2_to_human1_selected_roc_curves.png", dpi=300, bbox_inches='tight')

# Finalize and save PR plot
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_ylim([0.0, 1.05])
ax_pr.set_xlim([0.0, 1.0])
ax_pr.set_title('Human2 (MLL) → Human1 (ATL): Precision-Recall Curves (Selected Genes)')
ax_pr.legend(loc="lower left")
ax_pr.grid(True, linestyle='--', alpha=0.7)
fig_pr.tight_layout()
fig_pr.savefig(OUT_DIR / "human2_to_human1_selected_pr_curves.png", dpi=300, bbox_inches='tight')

plt.close('all')

# Save summary results
results_df = pd.DataFrame(results)
print("\nFinal results:")
print(results_df)

results_df.to_csv(OUT_DIR / "human2_to_human1_selected_summary_results.csv", index=False)
print(f"\nAll results saved to {OUT_DIR}")

Loading MLL dataset (training data)...
Renaming 15 genes using alias dictionary
MLL dataset: 87,171 cells  (♀ 44716  ♂ 42455)

Loading ATL dataset (test data)...
Renaming 15 genes using alias dictionary
ATL dataset: 9,292 cells  (♀ 2811  ♂ 6481)

Extracting 1/15 random subsample from MLL dataset...
Training subsample: 5,811 cells  (♀ 2981  ♂ 2830)
Subsampling ratio: 6.7% of original MLL data

Extracting selected marker genes from training data...
Markers used (9): ['RPS4Y1', 'EIF1AY', 'Xist', 'Ddx3y', 'Uty', 'Kdm5d', 'IFIT3', 'IFIT2', 'RPS4X']

Extracting selected marker genes from test data...
Markers used (9): ['RPS4Y1', 'EIF1AY', 'Xist', 'Ddx3y', 'Uty', 'Kdm5d', 'IFIT3', 'IFIT2', 'RPS4X']

Common markers used for training and testing (9): ['Ddx3y', 'EIF1AY', 'IFIT2', 'IFIT3', 'Kdm5d', 'RPS4X', 'RPS4Y1', 'Uty', 'Xist']

Training and evaluating models using selected genes

=== LogisticRegression ===
 TRAIN → Acc=0.8954, F1=0.8816, AUROC=0.9615, AUPRC=0.9641
 TEST → Acc=0.9523, F1=0.96

## Kidney Dataset: Donor Cell Analysis (GSE151671)

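### Overview
This section prepares donor kidney cells for unbiased expression-sex analysis. The dataset combines one healthy kidney (HK) with two transplanted kidneys (AK1, AK2) whose donor and recipient differ in sex, so only kidney-resident parenchymal cells are kept and labeled with the donor's sex, while recipient-derived immune cells are removed. It serves as a validation set for testing model generalizability across tissue types.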
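
Because donor cells are selected by cell-type label, the matching rule matters. A substring test such as `"T" in "PT"` is true, so single-letter immune labels can sweep up parenchymal types. The sketch below contrasts a substring rule with an exact-match rule. The label sets are an illustrative subset and the function names are hypothetical. Whether exact matching suits this dataset depends on how labels are spelled in the metadata file, which is not shown here.

```python
# Illustrative subsets of the label sets used for filtering (upper-cased)
PARENCHYMAL = {"PT", "PTC1", "CD", "IC-A", "IC-B", "DVR", "FB1", "VSMC2"}
IMMUNE = {"B", "T", "NK", "MACRO", "DC"}

def normalize_label(cell_type):
    """Upper-case and strip a label; None stays None."""
    return None if cell_type is None else str(cell_type).strip().upper()

def is_immune_substring(cell_type):
    """Substring rule: flags any label that merely contains an immune token."""
    c = normalize_label(cell_type)
    return c is not None and any(t in c for t in IMMUNE)

def is_kidney_cell_exact(cell_type):
    """Exact rule: keep a label only if it equals a parenchymal entry."""
    c = normalize_label(cell_type)
    return c in PARENCHYMAL and c not in IMMUNE

for label in ["PT", "IC-B", "vSMC2", "T", "Macro"]:
    print(f"{label:6s} substring-immune={is_immune_substring(label)!s:5s} "
          f"exact-kidney={is_kidney_cell_exact(label)}")
```

Under the substring rule, `PT` and `IC-B` are flagged as immune because they contain `T` and `B`; under the exact rule they are kept and only the true immune labels are dropped.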

In [None]:
#!/usr/bin/env python3
"""
Kidney scRNA-seq merge + QC (UNBIASED VERSION)  •  GSE151671
──────────────────────────────────────────────────────────
FOCUS: Only donor kidney cells for unbiased expression-sex analysis
• Load three DGE tables (HK, AK1, AK2)  
• Filter to keep ONLY parenchymal/kidney-resident cells (donor cells)
• Assign sex based on DONOR sex (not recipient)
• Remove recipient immune cells that confound sex-expression analysis
• Save clean donor kidney dataset
"""

import os, warnings
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp

# ───────────── USER PATHS ─────────────
DATA_DIR = "data/GSE151671_RAW"
data_files = {
    "HK":  "GSM4587971_HK.dge.txt.gz",    # female healthy kidney
    "AK1": "GSM4587972_AK1.dge.txt.gz",   # female recipient, MALE DONOR
    "AK2": "GSM4587973_AK2.dge.txt.gz",   # male recipient, FEMALE DONOR
}
METADATA_PATH = "data/GSE151671_Cell_barcode_assignment.xlsx"
OUTFILE       = os.path.join(DATA_DIR, "kidney_unbias.h5ad")

VERBOSE = True
TRACE_ASSIGN_SEX = False

# ───────────── KIDNEY CELL TYPES ──────────────
# Only parenchymal/resident kidney cells (DONOR cells)
_kidney_parenchymal = {
    "PT", "PT1", "PT2", "PG", "LH", "CD", "IC-A", "IC-B", 
    "PTC1", "PTC2", "PTC3", "DVR", "FB1", "FB2", "FB3", "FB4",
    "EC", "PC1", "PC2", "VSMC1", "VSMC2", "VSMC3", "VSMC4",
}

# Immune/circulating cells to EXCLUDE (these are recipient cells)
_immune_circulating = {
    "B", "T", "NK", "Macro", "DC", "Neutrophil", "Monocyte", 
    "Lymphocyte", "Plasma", "Mast"
}

def is_kidney_cell(cell_type: str) -> bool:
    """Check if cell type is kidney parenchymal (donor) cell."""
    if pd.isna(cell_type):
        return False
    c = str(cell_type).upper().strip()
    
    # Direct match or substring match for parenchymal
    is_parenchymal = any(t == c or t in c or c.startswith(t) for t in _kidney_parenchymal)
    
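    # Exclude if clearly immune/circulating.
    # NOTE: this is a substring test on the upper-cased label, so the single-letter
    # entries "T" and "B" also match parenchymal labels that contain those letters
    # (e.g. "PT", "PTC1", "FB1", "IC-B"), while mixed-case entries such as "Macro"
    # can never match an upper-cased label.
    is_immune = any(t in c or c.startswith(t) for t in _immune_circulating)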
    
    return is_parenchymal and not is_immune

def assign_donor_sex(sample: str, cell_type: str) -> str:
    """Return donor sex for kidney cells only."""
    if not is_kidney_cell(cell_type):
        return "exclude"  # Mark for removal
    
    if TRACE_ASSIGN_SEX:
        print(f"assign_donor_sex({sample}, {cell_type}) -> kidney cell")
    
    # Sex assignment based on DONOR
    s = sample.upper()
    if s == "HK":  return "female"  # healthy female kidney
    if s == "AK1": return "male"    # male donor kidney  
    if s == "AK2": return "female"  # female donor kidney
    return "unknown"

def read_dge(sample: str, path: str) -> sc.AnnData:
    """Load one DGE file → AnnData with prefixed barcodes."""
    if VERBOSE: print(f"• Loading {sample}")
    df = pd.read_csv(path, sep="\t", index_col=0)
    if df.shape[0] > df.shape[1]:
        df = df.T
    if df.isna().values.any():
        df = df.fillna(0)
    barcodes = [f"{sample}_{bc}" for bc in df.index]
    adata = sc.AnnData(
        X   = sp.csr_matrix(df.values),
        obs = pd.DataFrame(index=barcodes),
        var = pd.DataFrame(index=df.columns),
    )
    adata.obs["sample"] = sample
    return adata

# ────────── 1) LOAD ALL SAMPLES ──────────
adatas = [
    read_dge(s, os.path.join(DATA_DIR, f))
    for s, f in data_files.items()
]

# ────────── 2) CONCATENATE ──────────
adata = sc.concat(
    adatas,
    join       = "outer",
    merge      = "first", 
    fill_value = 0,
)

if not sp.issparse(adata.X):
    adata.X = sp.csr_matrix(adata.X)
if VERBOSE:
    print(f"• Concatenated: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# ────────── 3) MERGE METADATA ──────────
meta = (
    pd.read_excel(METADATA_PATH, engine="openpyxl")
      .rename(columns=lambda x: x.strip())
      .assign(Cell_barcode=lambda d: d["Cell_barcode"].astype(str).str.strip().str.upper())
      .set_index("Cell_barcode")
)
adata.obs = adata.obs.merge(meta, left_index=True, right_index=True, how="left")

# ────────── 4) FILTER TO KIDNEY CELLS ONLY ──────────
if "Cell_type" not in adata.obs.columns:
    warnings.warn("Metadata lacks 'Cell_type'; cannot filter kidney cells.")
    adata.obs["Cell_type"] = "unknown"

# Assign donor sex and filter
adata.obs["donor_sex"] = adata.obs.apply(
    lambda r: assign_donor_sex(r["sample"], r["Cell_type"]), axis=1
)

# Keep only kidney cells (exclude immune/circulating cells)
before_filter = adata.n_obs
kidney_mask = adata.obs["donor_sex"] != "exclude"
adata = adata[kidney_mask, :].copy()

if VERBOSE:
    print(f"• Filtered to kidney cells: {before_filter:,} → {adata.n_obs:,} cells")
    print("• Donor sex distribution:")
    print(adata.obs["donor_sex"].value_counts(dropna=False))
    print("• Sample distribution:")  
    print(adata.obs["sample"].value_counts())
    print("• Cell types kept:")
    print(adata.obs["Cell_type"].value_counts().head(10))

# ────────── 5) BASIC QC & NORMALISATION ──────────
# Filter genes (min 3 cells)
sc.pp.filter_genes(adata, min_cells=3)

# Filter cells (min 200 genes)
sc.pp.filter_cells(adata, min_genes=200)

# Mitochondrial QC
adata.var["mt"] = adata.var_names.str.startswith(("MT-", "mt-"))
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

 

# NOTE: pct_counts_mt is computed above, but no mitochondrial cutoff is applied in this cell
# Normalize and log-transform
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)
sc.pp.log1p(adata)

# Add additional metadata
adata.obs["dataset"] = "kidney_unbias"
adata.obs["filtered_for"] = "donor_kidney_cells_only"

# ────────── 6) SAVE RESULTS ──────────
adata.write(OUTFILE, compression="gzip")

# ────────── SUMMARY ──────────
print("\n" + "="*50)
print("✅ KIDNEY_UNBIAS COMPLETED")
print("="*50)
print(f"Final dataset: {adata.n_obs:,} cells × {adata.n_vars:,} genes")
print(f"Saved to: {OUTFILE}")
print("\nDONOR SEX DISTRIBUTION:")
print(adata.obs["donor_sex"].value_counts(dropna=False))
print("\nSAMPLE DISTRIBUTION:")
print(adata.obs["sample"].value_counts())
print("\nTOP CELL TYPES (kidney only):")
print(adata.obs["Cell_type"].value_counts().head(10))

print("\n🎯 FOCUS: Only donor kidney cells included")
print("   • HK cells: female donor (healthy)")
print("   • AK1 cells: male donor (transplanted)")
print("   • AK2 cells: female donor (transplanted)")
print("   • Excluded: All recipient immune/circulating cells")
print("   • Ready for unbiased sex-expression analysis!")

• Loading HK
• Loading AK1
• Loading AK2
• Concatenated: 18,000 cells × 21,203 genes




• Filtered to kidney cells: 18,000 → 1,404 cells
• Donor sex distribution:
donor_sex
female    1200
male       204
Name: count, dtype: int64
• Sample distribution:
sample
HK     1019
AK1     204
AK2     181
Name: count, dtype: int64
• Cell types kept:
Cell_type
vSMC2    265
CD       235
vSMC1    225
PC1      185
DVR      144
IC-A     129
PC2      124
vSMC3     57
vSMC4     40
Name: count, dtype: int64





✅ KIDNEY_UNBIAS COMPLETED
Final dataset: 1,404 cells × 14,886 genes
Saved to: GSE151671_RAW/kidney_unbias.h5ad

DONOR SEX DISTRIBUTION:
donor_sex
female    1200
male       204
Name: count, dtype: int64

SAMPLE DISTRIBUTION:
sample
HK     1019
AK1     204
AK2     181
Name: count, dtype: int64

TOP CELL TYPES (kidney only):
Cell_type
vSMC2    265
CD       235
vSMC1    225
PC1      185
DVR      144
IC-A     129
PC2      124
vSMC3     57
vSMC4     40
Name: count, dtype: int64

🎯 FOCUS: Only donor kidney cells included
   • HK cells: female donor (healthy)
   • AK1 cells: male donor (transplanted)
   • AK2 cells: female donor (transplanted)
   • Excluded: All recipient immune/circulating cells
   • Ready for unbiased sex-expression analysis!


## Dataset 2: AML/MLL Bone Marrow Cells - GSE289435 (PRIMARY DATASET)

### Background
This dataset contains bone marrow mononuclear cells from patients with acute myeloid leukemia (AML) carrying mixed lineage leukemia (MLL) rearrangements. **This serves as our primary dataset for sex marker discovery and model training** due to its large sample size, balanced sex distribution, and comprehensive cell type diversity.

### Dataset Characteristics
- **Source**: GSE289435 (10X Genomics)
- **Samples**: 13 patient samples (MLL_14666 to MLL_30886)
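- **Sex Distribution**: 5 female and 8 male samples according to the `sample_info` labels below; near-balanced at the cell level after QC (44,716 female and 42,455 male cells)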
- **Cell Type**: Bone marrow mononuclear cells
- **Primary Role**: **Training data and marker gene discovery**
- **Secondary Role**: Self-validation through train/test splits

### Significance as Primary Dataset
This dataset is chosen as the foundation for analysis because:
- **Large sample size**: 13 patients provide statistical power
- **Near-balanced sex distribution at the cell level**: Enables robust model training
- **Cell type diversity**: Bone marrow contains multiple hematopoietic lineages
- **High data quality**: 10X Genomics v3 chemistry
- **Clinical relevance**: AML represents a well-characterized malignancy

### Analysis Applications
1. **Feature importance analysis**: Identify most informative sex-specific genes
2. **Model training**: Develop and optimize ML algorithms
3. **Internal validation**: 80/20 train/test splits for initial validation
4. **Cross-dataset training**: Train models for testing on other tissue types
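
The feature-importance method itself is not shown in this section. As a simple univariate stand-in, the sketch below ranks genes by the absolute difference in mean log-normalized expression between male and female cells. The function name `rank_genes_by_sex_gap` and the toy expression values are assumptions for illustration, not the notebook's procedure.

```python
import numpy as np

def rank_genes_by_sex_gap(X, y, gene_names):
    """Rank genes by |mean(male) - mean(female)|; y uses 0 = female, 1 = male."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    gap = X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)
    order = np.argsort(-np.abs(gap))  # largest absolute gap first
    return [(gene_names[i], float(gap[i])) for i in order]

# Toy log-normalized expression (hypothetical): rows are cells, columns are genes
genes = ["XIST", "RPS4Y1", "RPS4X"]
X = [
    [2.0, 0.0, 1.5],   # female
    [1.8, 0.0, 1.4],   # female
    [0.0, 1.6, 1.2],   # male
    [0.1, 1.9, 1.3],   # male
]
y = [0, 0, 1, 1]

ranking = rank_genes_by_sex_gap(X, y, genes)
for gene, gap in ranking:
    print(f"{gene:7s} male - female = {gap:+.2f}")
```

A negative gap marks a female-biased gene and a positive gap a male-biased one. Model-based importances (for example from the tree ensembles listed in the methodology) can rank genes differently because they account for interactions between markers.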

In [12]:
#!/usr/bin/env python3
"""
MLL scRNA-seq merge + QC • GSE289435
─────────────────────────────────────
- Load 10X Genomics datasets for MLL patient samples
- Prefix barcodes with sample name
- Concatenate matrices
- Assign sex labels based on provided information
- Perform QC filtering (min_genes, mitochondrial content)
- Library-size normalize, log1p transform
- Save sparse .h5ad
"""

import os
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import scipy.io as spio
import sys

# ───────────── USER PATHS ─────────────
# Relative paths, matching the ATL cell and the "data/human_2" path read by the cross-dataset cell
DATA_DIR = "data/GSE289435_RAW"
OUTDIR = "data/human_2"
OUTFILE = os.path.join(OUTDIR, "mll_merged_qc.h5ad")

# Create output directory if it doesn't exist
os.makedirs(OUTDIR, exist_ok=True)

# Sample information with sex labels
sample_info = {
    "MLL_14666": {"gsm_id": "8791432", "sex": "female"},
    "MLL_16703": {"gsm_id": "8791433", "sex": "female"},
    "MLL_17746": {"gsm_id": "8791434", "sex": "male"},
    "MLL_17843": {"gsm_id": "8791435", "sex": "female"},
    "MLL_17844": {"gsm_id": "8791436", "sex": "male"},
    "MLL_28824": {"gsm_id": "8791437", "sex": "male"},
    "MLL_28830": {"gsm_id": "8791438", "sex": "male"},
    "MLL_28855": {"gsm_id": "8791439", "sex": "female"},
    "MLL_29512_PDX": {"gsm_id": "8791440", "sex": "male"},
    "MLL_29532": {"gsm_id": "8791441", "sex": "male"},
    "MLL_29538": {"gsm_id": "8791442", "sex": "female"},
    "MLL_30862": {"gsm_id": "8791443", "sex": "male"},
    "MLL_30886": {"gsm_id": "8791444", "sex": "male"}
}

VERBOSE = True  # one-line progress prints

# ───────────── HELPERS ──────────────
def load_10x_data(sample_id: str, data_dir: str) -> sc.AnnData:
    """Load one 10X Genomics dataset using direct file loading."""
    if VERBOSE: print(f"• Loading {sample_id}")
    
    # Get GSM ID from the sample info
    gsm_id = sample_info[sample_id]["gsm_id"]
    
    # File paths
    matrix_file = os.path.join(data_dir, f"GSM{gsm_id}_{sample_id}.matrix.mtx.gz")
    features_file = os.path.join(data_dir, f"GSM{gsm_id}_{sample_id}.features.tsv.gz")
    barcodes_file = os.path.join(data_dir, f"GSM{gsm_id}_{sample_id}.barcodes.tsv.gz")
    
    # Check if files exist
    for file_path in [matrix_file, features_file, barcodes_file]:
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return None
    
    try:
        # Load matrix
        X = spio.mmread(matrix_file).T.tocsr()
        
        # Load features (genes)
        features = pd.read_csv(features_file, sep='\t', header=None)
        if features.shape[1] >= 2:
            # The second column typically contains gene symbols
            gene_names = features[1].values
        else:
            # If there's only one column, use it as both ID and name
            gene_names = features[0].values
        
        # Make gene names unique
        gene_names_unique = make_unique_names(gene_names)
        
        # Load barcodes
        barcodes = pd.read_csv(barcodes_file, sep='\t', header=None)[0].values
        
        # Create AnnData object
        adata = sc.AnnData(X=X)
        adata.obs_names = [f"{sample_id}_{bc}" for bc in barcodes]
        adata.var_names = gene_names_unique
        
        # Add sample and sex information
        adata.obs["sample"] = sample_id
        adata.obs["sex"] = sample_info[sample_id]["sex"]
        
        return adata
    
    except Exception as e:
        print(f"Error loading {sample_id}: {e}")
        return None

def make_unique_names(names):
    """Make duplicate names unique by appending numbers."""
    name_counts = {}
    unique_names = []
    
    for name in names:
        if name in name_counts:
            name_counts[name] += 1
            unique_names.append(f"{name}_{name_counts[name]}")
        else:
            name_counts[name] = 0
            unique_names.append(name)
    
    return unique_names

# ────────── 1) LOAD ALL SAMPLES ──────────
adatas = []
for sample_id in sample_info.keys():
    adata = load_10x_data(sample_id, DATA_DIR)
    if adata is not None:
        adatas.append(adata)
        if VERBOSE:
            print(f"• Successfully loaded {sample_id}: {adata.n_obs} cells × {adata.n_vars} genes")
    else:
        print(f"❌ Failed to load {sample_id}")

# Check if we have data to work with
if not adatas:
    print("No data was successfully loaded. Exiting.")
    sys.exit(1)

# ────────── 2) CONCATENATE ──────────
if VERBOSE: print("• Concatenating datasets")
if len(adatas) == 1:
    # If only one dataset was loaded, skip concatenation
    adata = adatas[0].copy()
    print("Only one dataset was loaded, skipping concatenation")
else:
    # Concatenate multiple datasets
    adata = sc.concat(
        adatas,
        join="outer",     # Union of genes
        merge="first",
        fill_value=0,     # Fill gaps with 0
    )

# Ensure the data is sparse
if not sp.issparse(adata.X):
    adata.X = sp.csr_matrix(adata.X)

if VERBOSE:
    print(f"• Concatenated: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# ────────── 3) BASIC QC & FILTERING ──────────
if VERBOSE: print("• Performing QC and filtering")

# Filter genes with low expression
sc.pp.filter_genes(adata, min_cells=3)

# Filter cells with few expressed genes
sc.pp.filter_cells(adata, min_genes=200)

# Identify mitochondrial genes (human MT genes start with MT-)
adata.var["mt"] = adata.var_names.str.startswith("MT-")

# Calculate QC metrics
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

# Filter cells with high mitochondrial content
max_mito_percent = 5
adata = adata[adata.obs["pct_counts_mt"] < max_mito_percent, :].copy()

# ────────── 4) NORMALIZATION ──────────
if VERBOSE: print("• Normalizing data")
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)
sc.pp.log1p(adata)

# ────────── 5) WRITE OUTPUT ──────────
if VERBOSE: print(f"• Writing to {OUTFILE}")
adata.write(OUTFILE, compression="gzip")

# ────────── SUMMARY ──────────
print("✅ Finished:")
print(f"   Cells  : {adata.n_obs:,}")
print(f"   Genes  : {adata.n_vars:,}")
print("   Sex    :")
print(adata.obs["sex"].value_counts(dropna=False))
print("   Samples:")
print(adata.obs["sample"].value_counts())

• Loading MLL_14666
• Successfully loaded MLL_14666: 12060 cells × 36601 genes
• Loading MLL_16703
• Successfully loaded MLL_16703: 14986 cells × 36601 genes
• Loading MLL_17746
• Successfully loaded MLL_17746: 9848 cells × 36601 genes
• Loading MLL_17843
• Successfully loaded MLL_17843: 19508 cells × 36601 genes
• Loading MLL_17844
• Successfully loaded MLL_17844: 25922 cells × 36601 genes
• Loading MLL_28824
• Successfully loaded MLL_28824: 8719 cells × 36601 genes
• Loading MLL_28830
• Successfully loaded MLL_28830: 10017 cells × 36601 genes
• Loading MLL_28855
• Successfully loaded MLL_28855: 8059 cells × 36601 genes
• Loading MLL_29512_PDX
• Successfully loaded MLL_29512_PDX: 9360 cells × 36601 genes
• Loading MLL_29532
• Successfully loaded MLL_29532: 8475 cells × 36601 genes
• Loading MLL_29538
• Successfully loaded MLL_29538: 6342 cells × 36601 genes
• Loading MLL_30862
• Successfully loaded MLL_30862: 4600 cells × 36601 genes
• Loading MLL_30886
• Successfully loaded MLL_30886

## Feature Importance Analysis: Sex Marker Discovery from Dataset 2

### Objective
**Primary Goal**: Systematically identify the most robust sex-specific markers using our comprehensive MLL/AML dataset (Dataset 2) as the discovery cohort.

### Rationale for Using Dataset 2
- **Statistical Power**: 13 patients provide sufficient samples for reliable importance estimates
- **Biological Diversity**: Bone marrow contains multiple cell types, identifying universal markers
- **Balanced Design**: Near-equal numbers of female and male cells (44,716 vs 42,455 after QC) limit class-imbalance bias in marker selection
- **Quality Data**: High-quality 10X data ensures accurate gene expression measurements

### Methodology
**Discovery Strategy**: 
- Extract a stratified 1/15 random subset (about 6.7% of cells) for computational efficiency
- Apply 4 complementary ML algorithms (Logistic, SVM, XGBoost, Random Forest)
- Use 5-fold cross-validation for robust importance estimates
- Identify consensus markers selected by ≥3 models
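
The consensus step can be sketched on toy data as follows. The gene lists below are hypothetical (the `GENE_*` names are placeholders, not results from this notebook), and `consensus_genes` is an illustrative helper, not a function from the original analysis.

```python
from collections import Counter

# Hypothetical per-model top-gene lists (toy data for illustration only).
top_genes = {
    "LogisticRegression": ["XIST", "RPS4Y1", "DDX3Y", "GENE_A"],
    "LinearSVC":          ["XIST", "RPS4Y1", "GENE_B", "DDX3Y"],
    "XGBoost":            ["RPS4Y1", "XIST", "GENE_C", "GENE_A"],
    "RandomForest":       ["XIST", "RPS4Y1", "DDX3Y", "GENE_D"],
}

def consensus_genes(top_lists, min_models=3):
    """Return genes that appear in the top list of at least `min_models` models."""
    counts = Counter(g for genes in top_lists.values() for g in set(genes))
    return sorted(g for g, c in counts.items() if c >= min_models)

print(consensus_genes(top_genes))
```

Raising `min_models` to 4 keeps only genes that every model ranks in its top list.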

### Expected Outcomes
- **Ranked gene lists**: Most important sex markers for each algorithm
- **Consensus markers**: Genes consistently selected across models
- **Validation of known markers**: Confirm XIST, RPS4Y1, etc. are top-ranked
- **Novel marker discovery**: Identify additional robust sex-specific genes

The following cell shows how to identify human markers with 5-fold cross-validation (a single random seed is used here).

In [None]:
#!/usr/bin/env python3
"""
Gene importance analysis for sex classification – GSE289435 MLL dataset
───────────────────────────────────────────────────────────────────────────
- Loads the pre‑processed `mll_merged_qc.h5ad` produced by the merge/QC           
  script (data‑2).                                                                  
- Extracts a stratified random 1/15 subsample of the entire dataset, then performs 5‑fold stratified CV
  using four ML pipelines, each bundling `StandardScaler()` + model:                                                      
    – Logistic Regression                                                          
    – Linear‑kernel SVC (probability = True)                                       
    – XGBoost                                                                       
    – Random Forest                                                                
- Collects and aggregates feature importances, writes per‑model CSV/plots,         
  consensus gene list, ROC+PR curves, and final metrics.                           

Edit the *PATHS* block if your directory structure changes.                         
"""

from __future__ import annotations
import os
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    confusion_matrix,
    roc_curve,
    precision_recall_curve,
    roc_auc_score,
    average_precision_score,
)

# ════════════════════ PATHS ════════════════════════════════
DATA_DIR  = Path("/Users/haley/Desktop/send_tooo/human_2").expanduser()
H5AD_FILE = DATA_DIR / "mll_merged_qc.h5ad"
OUT_DIR   = DATA_DIR / "sex_marker_analysis_subsample"
FEATURE_DIR = OUT_DIR / "feature_importance"
FEATURE_DIR.mkdir(parents=True, exist_ok=True)

# ════════════════════ HELPERS ══════════════════════════════

def make_pipe(model):
    """Return a sklearn Pipeline: StandardScaler → model."""
    return Pipeline([
        ("scaler", StandardScaler()),
        ("clf",    model),
    ])


def unify_gene_symbols(adata: sc.AnnData) -> sc.AnnData:
    adata.var_names = adata.var_names.astype(str).str.upper()
    if not adata.var_names.is_unique:
        adata.var_names_make_unique()
    return adata


def extract_sex_labels(adata: sc.AnnData):
    if "sex" not in adata.obs:
        raise KeyError("'sex' column missing in AnnData.obs")
    y = (adata.obs["sex"].astype(str).str.lower().str.strip() == "male").astype(int)
    return adata.copy(), y.values


def to_dataframe(adata: sc.AnnData) -> pd.DataFrame:
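    # .toarray() is the portable way to densify; the .A shorthand is not available on all SciPy sparse types
    X = adata.X.toarray() if sp.issparse(adata.X) else adata.X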
    df = pd.DataFrame(X, index=adata.obs_names, columns=adata.var_names)
    nonconst = (df != df.iloc[0]).any()
    dropped = (~nonconst).sum()
    if dropped:
        warnings.warn(f"Dropping {dropped} constant genes")
        df = df.loc[:, nonconst]
    return df

# ════════════════════ 1) LOAD DATA ═════════════════════════
print("[STEP] Reading mll_merged_qc.h5ad …")
adata = sc.read_h5ad(H5AD_FILE)
adata = unify_gene_symbols(adata)
adata, y = extract_sex_labels(adata)
print(f"▶ Dataset: {adata.n_obs:,} cells ─ ♀ {(y==0).sum()}  ♂ {(y==1).sum()}")

# ════════════════════ 2) EXTRACT 1/15 RANDOM SUBSAMPLE ═════════════════
# Subsample 1/15 of the data while maintaining sex distribution
print("[STEP] Extracting 1/15 random subsample of data...")
subsample_size = len(adata) // 15

indices = np.arange(len(adata))
_, subsample_indices, _, y_subsample = train_test_split(
    indices, y, test_size=subsample_size/len(adata), 
    stratify=y, random_state=42
)

# Create subsampled AnnData object
adata_subsample = adata[subsample_indices].copy()
y_subsample = y[subsample_indices]
print(f"▶ Subsample: {adata_subsample.n_obs:,} cells ─ ♀ {(y_subsample==0).sum()}  ♂ {(y_subsample==1).sum()}")
print(f"▶ Subsampling ratio: {adata_subsample.n_obs/adata.n_obs:.1%} of original data")

# ════════════════════ 3) EXPRESSION MATRIX ═════════════════
X_df = to_dataframe(adata_subsample)
print(f"▶ Matrix: {X_df.shape[0]} cells × {X_df.shape[1]} genes")

# ════════════════════ 4) PIPELINES ═════════════════════════
PIPES = {
    "LogisticRegression": make_pipe(LogisticRegression(max_iter=1000, random_state=42)),
    "LinearSVC":          make_pipe(SVC(kernel="linear", probability=True, random_state=42)),
    "XGBoost":            make_pipe(XGBClassifier(
        eval_metric="logloss", random_state=42,
        n_estimators=100, learning_rate=0.05, max_depth=10)),
    "RandomForest":       make_pipe(RandomForestClassifier(max_depth=10, random_state=42)),
}

feat_imps = {name: [] for name in PIPES}
preds    = {name: {"y_true": [], "y_prob": []} for name in PIPES}

# ════════════════════ 5) 5‑FOLD CV ON SUBSAMPLE ════════════════════
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (tr, te) in enumerate(skf.split(X_df, y_subsample), 1):
    print(f"\n[CV] Fold {fold}/5")
    X_tr, X_te = X_df.iloc[tr], X_df.iloc[te]
    y_tr, y_te = y_subsample[tr], y_subsample[te]
    
    print(f"   ▶ Training: {len(X_tr)} cells, Testing: {len(X_te)} cells")

    for name, pipe in PIPES.items():
        print(f"   – {name}…", end="", flush=True)
        pipe.fit(X_tr, y_tr)

        # underlying estimator
        est = pipe.named_steps["clf"]
        if hasattr(est, "coef_"):
            imp = np.abs(est.coef_[0])
        elif hasattr(est, "feature_importances_"):
            imp = est.feature_importances_
        else:
            imp = np.zeros(X_tr.shape[1])
        feat_imps[name].append(pd.DataFrame({"Feature": X_tr.columns, "Importance": imp}))

        y_prob = pipe.predict_proba(X_te)[:, 1]
        preds[name]["y_true"].extend(y_te)
        preds[name]["y_prob"].extend(y_prob)

        y_pred = (y_prob >= 0.5).astype(int)
        acc = accuracy_score(y_te, y_pred)
        f1  = f1_score(y_te, y_pred)
        auc = roc_auc_score(y_te, y_prob)
        ap  = average_precision_score(y_te, y_prob)
        print(f" Acc={acc:.3f} F1={f1:.3f} AUROC={auc:.3f} AUPRC={ap:.3f}")

# ════════════════════ 6) IMPORTANCE AGGREGATION ═══════════
print("\n[STEP] Aggregating feature importances …")
agg, top20 = {}, {}
for name, dfs in feat_imps.items():
    mean_imp = (pd.concat(dfs)
                  .groupby("Feature")["Importance"].mean()
                  .sort_values(ascending=False)
                  .reset_index())
    mean_imp["Rank"] = mean_imp["Importance"].rank(method="dense", ascending=False).astype(int)
    agg[name] = mean_imp
    csv_path = FEATURE_DIR / f"{name}_feature_importances.csv"
    mean_imp.to_csv(csv_path, index=False)

    top20[name] = mean_imp.head(20)["Feature"].tolist()
    plt.figure(figsize=(9, 7))
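    # seaborn >= 0.13: assign `hue` when passing `palette`, and hide the redundant legend
    sns.barplot(data=mean_imp.head(20), y="Feature", x="Importance",
                hue="Feature", palette="viridis", legend=False)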
    plt.title(f"Top‑20 genes – {name} (MLL Dataset)")
    plt.tight_layout()
    plt.savefig(FEATURE_DIR / f"{name}_top20.png", dpi=300, bbox_inches="tight")
    plt.close()

# consensus list
counts = {}
for genes in top20.values():
    for g in genes:
        counts[g] = counts.get(g, 0) + 1
consensus = (pd.Series(counts, name="Models_Count")
               .sort_values(ascending=False)
               .reset_index().rename(columns={"index": "Gene"}))
consensus.to_csv(FEATURE_DIR / "consensus_top_genes.csv", index=False)

# ════════════════════ 7) ROC & PR CURVES ═════════════════
print("\n[STEP] Generating ROC and PR curves …")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# ROC Curves
for name, p in preds.items():
    fpr, tpr, _ = roc_curve(p["y_true"], p["y_prob"])
    auc = roc_auc_score(p["y_true"], p["y_prob"])
    ax1.plot(fpr, tpr, label=f"{name} (AUC = {auc:.3f})")

ax1.plot([0, 1], [0, 1], 'k--')
ax1.set_xlabel('False Positive Rate')
ax1.set_ylabel('True Positive Rate')
ax1.set_title('ROC Curves')
ax1.legend()

# PR Curves
for name, p in preds.items():
    precision, recall, _ = precision_recall_curve(p["y_true"], p["y_prob"])
    ap = average_precision_score(p["y_true"], p["y_prob"])
    ax2.plot(recall, precision, label=f"{name} (AP = {ap:.3f})")

# Calculate baseline for PR curve (proportion of positive class)
baseline = sum(preds[list(preds.keys())[0]]["y_true"]) / len(preds[list(preds.keys())[0]]["y_true"])
ax2.plot([0, 1], [baseline, baseline], 'k--', label=f'Baseline ({baseline:.3f})')

ax2.set_xlabel('Recall')
ax2.set_ylabel('Precision')
ax2.set_title('Precision-Recall Curves')
ax2.legend()

plt.tight_layout()
plt.savefig(OUT_DIR / "roc_pr_curves.png", dpi=300, bbox_inches="tight")
plt.close()

# ════════════════════ 8) FINAL METRICS ═════════════════
print("\n[STEP] Summarizing metrics …")
metrics = []
for name, p in preds.items():
    y_true = np.array(p["y_true"])
    y_prob = np.array(p["y_prob"])
    y_pred = (y_prob >= 0.5).astype(int)
    
    metrics.append({
        "Model": name,
        "Accuracy": accuracy_score(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
        "AUROC": roc_auc_score(y_true, y_prob),
        "AUPRC": average_precision_score(y_true, y_prob)
    })

metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv(OUT_DIR / "model_metrics.csv", index=False)

# Print a nice summary table
print("\nFinal 5-fold CV metrics:\n" + "-" * 50)
print(metrics_df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))
print("-" * 50)
print(f"\nResults saved to {OUT_DIR}")

[STEP] Reading mll_merged_qc.h5ad …
▶ Dataset: 87,171 cells ─ ♀ 44716  ♂ 42455
[STEP] Extracting 1/15 random subsample of data...
▶ Subsample: 5,811 cells ─ ♀ 2981  ♂ 2830
▶ Subsampling ratio: 6.7% of original data




▶ Matrix: 5811 cells × 22980 genes

[CV] Fold 1/5
   ▶ Training: 4648 cells, Testing: 1163 cells
   – LogisticRegression… Acc=0.983 F1=0.982 AUROC=0.999 AUPRC=0.999
   – LinearSVC… Acc=0.990 F1=0.989 AUROC=0.999 AUPRC=0.999
   – XGBoost… Acc=0.982 F1=0.981 AUROC=0.998 AUPRC=0.999
   – RandomForest… Acc=0.936 F1=0.934 AUROC=0.987 AUPRC=0.988

[CV] Fold 2/5
   ▶ Training: 4649 cells, Testing: 1162 cells
   – LogisticRegression… Acc=0.985 F1=0.985 AUROC=0.997 AUPRC=0.998
   – LinearSVC… Acc=0.982 F1=0.981 AUROC=0.997 AUPRC=0.997
   – XGBoost… Acc=0.985 F1=0.985 AUROC=0.998 AUPRC=0.998
   – RandomForest… Acc=0.933 F1=0.931 AUROC=0.986 AUPRC=0.985

[CV] Fold 3/5
   ▶ Training: 4649 cells, Testing: 1162 cells
   – LogisticRegression… Acc=0.983 F1=0.982 AUROC=0.999 AUPRC=0.998
   – LinearSVC… Acc=0.981 F1=0.981 AUROC=0.998 AUPRC=0.998
   – XGBoost… Acc=0.985 F1=0.985 AUROC=0.999 AUPRC=0.999
   – RandomForest… Acc=0.940 F1=0.937 AUROC=0.987 AUPRC=0.987

[CV] Fold 4/5
   ▶ Training: 4649 cells,


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=mean_imp.head(20), y="Feature", x="Importance", palette="viridis")



[STEP] Generating ROC and PR curves …

[STEP] Summarizing metrics …

Final 5-fold CV metrics:
--------------------------------------------------
             Model  Accuracy     F1  AUROC  AUPRC
LogisticRegression    0.9802 0.9797 0.9977 0.9976
         LinearSVC    0.9804 0.9798 0.9969 0.9970
           XGBoost    0.9845 0.9839 0.9984 0.9985
      RandomForest    0.9368 0.9345 0.9857 0.9859
--------------------------------------------------

Results saved to /Users/haley/Desktop/send_tooo/human_2/sex_marker_analysis_subsample


In [11]:
# ----------------------------------------
#  right after you build `counts` or `consensus`
# ----------------------------------------

# Option 1 – work directly from the `counts` dict
genes_over_two = [g for g, c in counts.items() if c > 2]
print("Genes selected by ≥3 models:", genes_over_two)

# Option 2 – use the consensus DataFrame you just wrote
genes_over_two = (
    consensus.loc[consensus["Models_Count"] > 2, "Gene"]
    .tolist()
)
print("Genes selected by ≥3 models:", genes_over_two)

# (optional) save to disk
(pd.Series(genes_over_two, name="Gene")
   .to_csv(FEATURE_DIR / "genes_selected_by_3plus_models.csv",
           index=False))


Genes selected by ≥3 models: ['RPS4Y1', 'EIF1AY', 'XIST', 'DDX3Y', 'UTY', 'KDM5D', 'IFIT3', 'IFIT2', 'RPS4X', 'RPL29']
Genes selected by ≥3 models: ['RPS4Y1', 'IFIT2', 'XIST', 'DDX3Y', 'EIF1AY', 'RPS4X', 'RPL29', 'IFIT3', 'UTY', 'KDM5D']


## Gene Selection Robustness Analysis Across Random Seeds

![Gene Selection Across Seeds](/Users/haley/Downloads/CellSexID-main/gene_seed_heatmap.png)

**Figure**: Cross-validation analysis of gene selection consistency across five random seeds (42, 123, 456, 789, 999) using 10% random subsampling of the full MLL/AML dataset (Dataset 2). This methodological approach addresses computational constraints of the large dataset while maintaining statistical rigor through random sampling and multiple seed validation.

### Methodological Approach:
- **Dataset Sampling**: 10% random subset of the complete dataset to manage computational complexity
- **Cross-Validation**: Five different random seeds to assess selection robustness  
- **Rationale**: Random subsampling at 10% provides a computationally feasible yet statistically sound approach for large-scale marker discovery
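
One way to tabulate selection consistency across seeds is a gene-by-seed indicator matrix. The sketch below uses toy gene sets (the `GENE_*` names and the per-seed selections are hypothetical, not the results shown in the figure), and `selection_matrix` is an illustrative helper.

```python
import pandas as pd

# Hypothetical consensus gene sets per seed (toy data for illustration only).
selected_by_seed = {
    42:  {"XIST", "RPS4Y1", "GENE_A"},
    123: {"XIST", "RPS4Y1"},
    456: {"XIST", "RPS4Y1", "GENE_B"},
}

def selection_matrix(selected):
    """Build a genes x seeds table with 1 where a gene was selected for that seed."""
    genes = sorted(set().union(*selected.values()))
    return pd.DataFrame(
        {seed: [int(g in s) for g in genes] for seed, s in selected.items()},
        index=genes,
    )

matrix = selection_matrix(selected_by_seed)
frequency = matrix.mean(axis=1)  # fraction of seeds in which each gene was selected
stable = frequency[frequency == 1.0].index.tolist()
print(matrix)
print("Selected in every seed:", stable)
```

A matrix of this shape can be passed directly to a heatmap function to produce a figure similar in layout to the one above.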

### Key Findings:

**Highly Robust Markers (4/4 model consistency):**
- DDX3Y, EIF1AY, IFIT2, UTY, RPS4Y1 - demonstrate exceptional stability across all random seeds

**Moderately Robust Markers (3-4/4 model consistency):**  
- RPS4X, XIST, IFIT3, KDM5D - show strong but not perfect consistency

**Seed-Sensitive Markers:**
- OASL: Selected only with seeds 456 and 789
- USP9Y: Selected exclusively with seed 42

### Conclusion:
This cross-seed analysis identified **9 consistently selected sex prediction markers** (5 highly robust and 4 moderately robust) whose selection is largely independent of the random seed. Subsampling 10% of the cells keeps marker discovery computationally feasible on this large dataset while still recovering stable markers.

These 9 robust markers form the foundation of our human sex prediction gene signature, validated through this rigorous cross-seed consistency analysis.

## Dataset 3: Thymic Epithelial Cells - GSE262749

### Background
Medullary thymic epithelial cells represent a highly specialized cell type involved in T-cell selection. This dataset tests model performance on tissue-specific epithelial cells.

### Dataset Characteristics
- **Source**: GSE262749 (10X Genomics)
- **Samples**: 5 donors (DonorA-E)
- **Sex Distribution**: 3 female, 2 male donors
- **Cell Type**: Medullary thymic epithelial cells
- **Role**: Validation on specialized epithelial tissue 

### Analytical Challenge
- **Tissue Specificity**: Thymic epithelial cells have unique expression profiles
- **Expected Performance**: Tests whether bone marrow-derived markers generalize to epithelial contexts

GEO annotation: Thymic mimetic cells in humans [scRNA-seq]; medullary thymic epithelial cells

In [4]:
#!/usr/bin/env python3
"""
scRNA-seq merge + QC • GSE262749
────────────────────────────────
• Load 10X Genomics matrices for five donor samples
• Prefix barcodes with sample ID
• Concatenate into a single AnnData
• Add sex labels
• QC:
   – remove low-quality genes/cells
   – compute % mitochondrial reads
   – discard cells with high mito %
   – **drop mitochondrial genes from the matrix**
• Library-size normalise, log1p transform
• Save compressed .h5ad
"""

import os
import sys
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import scipy.io as spio

# ───────────── USER PATHS ─────────────
DATA_DIR = "/Users/haley/Desktop/send_tooo/GSE262749_RAW"
OUTDIR   = "/Users/haley/Desktop/send_tooo/human_3"
OUTFILE  = os.path.join(OUTDIR, "donor_merged_qc.h5ad")
os.makedirs(OUTDIR, exist_ok=True)

# ───────────── SAMPLE METADATA ─────────────
sample_info = {
    "DonorA": {"gsm_id": "8178134", "sex": "female"},
    "DonorB": {"gsm_id": "8178135", "sex": "male"},
    "DonorC": {"gsm_id": "8178136", "sex": "female"},
    "DonorD": {"gsm_id": "8178137", "sex": "male"},
    "DonorE": {"gsm_id": "8178138", "sex": "female"},
}

VERBOSE = True

# ───────────── HELPERS ──────────────
def make_unique(names):
    """Ensure gene names are unique (Scanpy requires this)."""
    counts, unique = {}, []
    for n in names:
        if n in counts:
            counts[n] += 1
            unique.append(f"{n}_{counts[n]}")
        else:
            counts[n] = 0
            unique.append(n)
    return unique


def load_10x(sample_id: str) -> sc.AnnData | None:
    """Load a single 10X dataset (matrix.mtx + features + barcodes)."""
    gsm = sample_info[sample_id]["gsm_id"]
    mat_f = os.path.join(DATA_DIR, f"GSM{gsm}_{sample_id}_matrix.mtx.gz")
    feat_f = os.path.join(DATA_DIR, f"GSM{gsm}_{sample_id}_features.tsv.gz")
    bc_f  = os.path.join(DATA_DIR, f"GSM{gsm}_{sample_id}_barcodes.tsv.gz")

    for fp in (mat_f, feat_f, bc_f):
        if not os.path.exists(fp):
            print(f"❌ File not found: {fp}")
            return None

    if VERBOSE:
        print(f"• Loading {sample_id}")

    try:
        X = spio.mmread(mat_f).T.tocsr()

        genes_df = pd.read_csv(feat_f, sep="\t", header=None)
        gene_names = genes_df[1].values if genes_df.shape[1] >= 2 else genes_df[0].values
        gene_names = make_unique(gene_names)

        barcodes = pd.read_csv(bc_f, sep="\t", header=None)[0].values
        # Cast before construction; the `dtype=` argument is deprecated in recent anndata releases
        ad = sc.AnnData(X.astype(np.int32))
        ad.obs_names = [f"{sample_id}_{bc}" for bc in barcodes]
        ad.var_names = gene_names
        ad.obs["sample"] = sample_id
        ad.obs["sex"]    = sample_info[sample_id]["sex"]
        return ad

    except Exception as e:
        print(f"❌ Error loading {sample_id}: {e}")
        return None


# ────────── 1) LOAD ALL SAMPLES ──────────
adatas = [d for s in sample_info for d in (load_10x(s),) if d is not None]

if not adatas:
    sys.exit("No data loaded — exiting.")

if VERBOSE:
    for ad in adatas:
        print(f"   ↳ {ad.obs['sample'].iloc[0]}: {ad.n_obs:,} cells  ×  {ad.n_vars:,} genes")

# ────────── 2) CONCATENATE ──────────
if VERBOSE:
    print("• Concatenating samples")

adata = adatas[0] if len(adatas) == 1 else sc.concat(adatas, join="outer", merge="first", fill_value=0)

if not sp.issparse(adata.X):
    adata.X = sp.csr_matrix(adata.X)

if VERBOSE:
    print(f"   Total: {adata.n_obs:,} cells  ×  {adata.n_vars:,} genes")

# ────────── 3) QC FILTERING ──────────
if VERBOSE:
    print("• QC filtering")

# a) gene / cell minimums
sc.pp.filter_genes(adata, min_cells=3)
sc.pp.filter_cells(adata, min_genes=200)

# b) percent mitochondrial
adata.var["mt"] = adata.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

adata = adata[adata.obs["pct_counts_mt"] < 5].copy()

# c) REMOVE MITOCHONDRIAL GENES
n_mt = int(adata.var["mt"].sum())
adata = adata[:, ~adata.var["mt"]].copy()

if VERBOSE:
    print(f"   Removed {n_mt} mitochondrial genes; {adata.n_vars:,} genes remain.")

# ────────── 4) NORMALISATION ──────────
if VERBOSE:
    print("• Normalising & log-transforming")

sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# ────────── 5) SAVE ──────────
if VERBOSE:
    print(f"• Saving → {OUTFILE}")

adata.write(OUTFILE, compression="gzip")

# ────────── SUMMARY ──────────
print("✅ Finished")
print(f"   Cells  : {adata.n_obs:,}")
print(f"   Genes  : {adata.n_vars:,}")
print("   Sex    :")
print(adata.obs["sex"].value_counts(dropna=False))
print("   Samples:")
print(adata.obs["sample"].value_counts())


• Loading DonorA




• Loading DonorB




• Loading DonorC




• Loading DonorD




• Loading DonorE




   ↳ DonorA: 5,885 cells  ×  36,616 genes
   ↳ DonorB: 8,215 cells  ×  36,616 genes
   ↳ DonorC: 6,126 cells  ×  36,616 genes
   ↳ DonorD: 6,633 cells  ×  36,616 genes
   ↳ DonorE: 9,361 cells  ×  36,616 genes
• Concatenating samples
   Total: 36,220 cells  ×  36,616 genes
• QC filtering




   Removed 13 mitochondrial genes; 32,399 genes remain.
• Normalising & log-transforming
• Saving → /Users/haley/Desktop/send_tooo/human_3/donor_merged_qc.h5ad
✅ Finished
   Cells  : 25,379
   Genes  : 32,399
   Sex    :
sex
female    13508
male      11871
Name: count, dtype: int64
   Samples:
sample
DonorD    6240
DonorB    5631
DonorA    5190
DonorE    4355
DonorC    3963
Name: count, dtype: int64


## Cross-Dataset Validation: Dataset 2 (MLL/AML) → Dataset 3 (Thymic)

### Experimental Design
This analysis tests the generalizability of sex prediction models trained on bone marrow cells (Dataset 2) when applied to thymic epithelial cells (Dataset 3). This represents one of the most challenging cross-dataset validations due to the biological distance between hematopoietic and epithelial cell types.

### Biological Rationale
**Training Context**: Bone marrow mononuclear cells (diverse hematopoietic lineages)  
**Testing Context**: Medullary thymic epithelial cells (specialized epithelial tissue)  
**Challenge**: Test whether sex markers discovered in blood/immune cells generalize to epithelial contexts

### Methodological Approach
**Selected Marker Panel**: 9 marker genes retained from the feature importance and cross-seed analyses above: RPS4Y1, EIF1AY, XIST, DDX3Y, UTY, KDM5D, IFIT3, IFIT2, RPS4X

**Training Strategy**:
- 4 ML algorithms: Logistic Regression, SVM, XGBoost, Random Forest

**Evaluation Framework**:
- Cross-dataset performance metrics (AUROC, AUPRC, accuracy)
- ROC and Precision-Recall curve generation
- Model comparison across algorithms
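
The train-on-one-dataset, test-on-another design can be sketched with synthetic data as below. This is an illustration under stated assumptions, not the notebook's actual pipeline: `make_synthetic` is a hypothetical helper, the two arrays merely stand in for Dataset 2 and Dataset 3, and only one of the four model types is shown.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def make_synthetic(n_cells, shift, seed):
    """Toy dataset: 9 features, the first 3 of which depend on the label."""
    rng = np.random.default_rng(seed)
    y = rng.integers(0, 2, size=n_cells)           # 0 = female, 1 = male
    X = rng.normal(loc=shift, scale=1.0, size=(n_cells, 9))
    X[:, :3] += 2.0 * y[:, None]                   # label-dependent signal
    return X, y

# Different seeds and baseline shifts mimic two independently collected datasets.
X_train, y_train = make_synthetic(2000, shift=0.0, seed=0)  # stands in for Dataset 2
X_test, y_test = make_synthetic(1000, shift=0.5, seed=1)    # stands in for Dataset 3

# Fit on the training dataset only; the scaler is never refit on test data.
model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)
y_prob = model.predict_proba(X_test)[:, 1]

auroc = roc_auc_score(y_test, y_prob)
auprc = average_precision_score(y_test, y_prob)
print(f"AUROC={auroc:.3f}  AUPRC={auprc:.3f}")
```

Keeping the scaler inside the pipeline means the test dataset is standardized with training-set statistics, which is the stricter choice for a cross-dataset evaluation.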

 

In [5]:
#!/usr/bin/env python3
"""
Sex classification using MLL data (human_2) subsample as training and Donor data (human_3) as testing
Using 9 selected marker genes: RPS4Y1, EIF1AY, XIST, DDX3Y, UTY, KDM5D, IFIT3, IFIT2, RPS4X
Models: LogisticRegression, Linear-SVC, XGBoost, Random-Forest
"""

# ─── imports ───────────────────────────────────────────────
import os, pathlib, warnings
import numpy as np, pandas as pd, scanpy as sc, scipy.sparse as sp
import matplotlib.pyplot as plt

from sklearn.impute            import SimpleImputer
from sklearn.preprocessing     import StandardScaler
from sklearn.pipeline          import Pipeline
from sklearn.linear_model      import LogisticRegression
from sklearn.svm               import SVC
from sklearn.ensemble          import RandomForestClassifier
from sklearn.model_selection   import train_test_split
from xgboost                   import XGBClassifier
from sklearn.metrics           import (
    accuracy_score, f1_score, roc_auc_score, average_precision_score,
    confusion_matrix, roc_curve, precision_recall_curve
)

# ─── selected marker panel ──────────────────────────────────────────
SELECTED_MARKERS = [
    "RPS4Y1", 
    "EIF1AY", 
    "XIST", 
    "DDX3Y", 
    "UTY", 
    "KDM5D", 
    "IFIT3", 
    "IFIT2", 
    "RPS4X"
]

# ─── alias dictionary for gene name mapping ───────────────────
alias_to_official = {
    "XIST":"Xist", "RPS27RT":"Rps27rt", "DDX3Y":"Ddx3y", "RPL35":"Rpl35",
    "EIF2S3Y":"Eif2s3y", "EIF2S3L":"Eif2s3y", "GM42418":"Gm42418", "UBA52":"Uba52",
    "RPL36A-PS1":"Rpl36a-ps1", "KDM5D":"Kdm5d", "JARID1D":"Kdm5d", "WDR89":"Wdr89",
    "UTY":"Uty", "LARS2":"Lars2", "AY036118":"AY036118", "RPL9-PS6":"Rpl9-ps6", "RPS27":"Rps27",
    "RPS4Y1":"RPS4Y1", "EIF1AY":"EIF1AY", "GNLY":"GNLY", "IFIT3":"IFIT3", "IFIT2":"IFIT2", 
    "RPS4X":"RPS4X", "RPL29":"RPL29", 
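    # NOTE: the first four rows map to mouse-style casing (e.g. "XIST" -> "Xist"),
    # while the last two rows keep human symbols unchanged. Marker lookup in
    # extract_marker_matrix is case-insensitive, so panel genes are still found.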
}

# ─── file paths ────────────────────────────────────────────
DATA_DIR = "/Users/haley/Desktop/send_tooo"
MLL_H5AD = os.path.join(DATA_DIR, "human_2/mll_merged_qc.h5ad")
DONOR_H5AD = os.path.join(DATA_DIR, "human_3/donor_merged_qc.h5ad")
OUT_DIR = pathlib.Path("/Users/haley/Downloads/CellSexID-main/result")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ─── helper functions ────────────────────────────────────────
def unify_gene_symbols(adata):
    """Normalize gene symbols using alias dictionary"""
    if not isinstance(adata.var_names, pd.Index):
        return adata
    
    # Create a mapping dictionary for renaming
    rename_dict = {}
    for gene in adata.var_names:
        # Check for aliases (case insensitive)
        gene_upper = gene.upper()
        if gene_upper in alias_to_official:
            rename_dict[gene] = alias_to_official[gene_upper]
    
    # Rename genes if aliases are found
    if rename_dict:
        print(f"Renaming {len(rename_dict)} genes using alias dictionary")
        adata.var_names = [rename_dict.get(g, g) for g in adata.var_names]
        
    # Make variable names unique if needed
    if not adata.var_names.is_unique:
        print("Making gene names unique")
        adata.var_names_make_unique()
        
    return adata

def extract_sex_labels(adata):
    """Extract standardized sex labels (0=female, 1=male)"""
    if "sex" not in adata.obs:
        raise ValueError("'sex' column not found in AnnData.")

    sex = (
        adata.obs["sex"]
          .astype(str).str.strip().str.lower()
          .map({"female": 0, "male": 1})
    )
    mask = sex.notna()
    return adata[mask].copy(), sex[mask].astype(int).values

def make_pipe(clf):
    """Create a preprocessing pipeline for a classifier"""
    steps = [("imp", SimpleImputer(strategy="median"))]
    if isinstance(clf, (LogisticRegression, SVC)):
        steps.append(("sc", StandardScaler(with_mean=False)))
    steps.append(("clf", clf))
    return Pipeline(steps)

def extract_marker_matrix(adata, markers):
    """Extract marker gene expression matrix from AnnData"""
    # Convert var_names to lowercase for case-insensitive matching
    var_lower = {g.lower(): g for g in adata.var_names}
    
    # Find markers present in dataset (case-insensitive)
    present = [var_lower[g.lower()] for g in markers if g.lower() in var_lower]
    
    if len(present) < 2:
        raise ValueError(f"Fewer than 2 marker genes present in dataset. Found: {present}")
    
    # Extract expression matrix as DataFrame
    X_df = pd.DataFrame(
        # .toarray() replaces the deprecated sparse-matrix `.A` attribute
        adata[:, present].X.toarray() if sp.issparse(adata.X) else adata[:, present].X,
        index=adata.obs_names,
        columns=present,
    )
    
    # Drop constant columns that don't provide information
    nonconst = (X_df != X_df.iloc[0]).any()
    if (~nonconst).any():
        dropped = X_df.columns[~nonconst].tolist()
        warnings.warn(f"Dropping constant marker(s): {dropped}")
        X_df = X_df.loc[:, nonconst]
        present = X_df.columns.tolist()
    
    if len(present) < 2:
        raise ValueError("Need ≥2 informative markers after filtering.")
    
    print(f"Markers used ({len(present)}): {present}")
    
    return X_df, present

# ═════════ 1) Load datasets ══════════════════════════
print("Loading MLL dataset...")
mll_adata = sc.read_h5ad(MLL_H5AD)
mll_adata = unify_gene_symbols(mll_adata)
mll_adata, mll_y = extract_sex_labels(mll_adata)
print(f"MLL dataset: {mll_adata.n_obs:,} cells  "
      f"(♀ {(mll_y==0).sum()}  ♂ {(mll_y==1).sum()})")

print("\nLoading Donor dataset (test data)...")
test_adata = sc.read_h5ad(DONOR_H5AD)
test_adata = unify_gene_symbols(test_adata)
test_adata, y_test = extract_sex_labels(test_adata)
print(f"Donor dataset: {test_adata.n_obs:,} cells  "
      f"(♀ {(y_test==0).sum()}  ♂ {(y_test==1).sum()})")

# ═════════ 2) Subsample MLL dataset (1/15) ══════════════════
print("\nExtracting 1/15 random subsample from MLL dataset...")
subsample_size = len(mll_adata) // 15

indices = np.arange(len(mll_adata))
_, subsample_indices, _, y_subsample = train_test_split(
    indices, mll_y, test_size=subsample_size,  # integer count avoids float rounding
    stratify=mll_y, random_state=42
)

# Create subsampled AnnData object for training
train_adata = mll_adata[subsample_indices].copy()
y_train = mll_y[subsample_indices]
print(f"Training subsample: {train_adata.n_obs:,} cells  "
      f"(♀ {(y_train==0).sum()}  ♂ {(y_train==1).sum()})")
print(f"Subsampling ratio: {train_adata.n_obs/mll_adata.n_obs:.1%} of original MLL data")

# ═════════ 3) Extract marker matrices ══════════════════
print("\nExtracting selected marker genes from training data...")
X_train_df, train_markers = extract_marker_matrix(train_adata, SELECTED_MARKERS)

print("\nExtracting selected marker genes from test data...")
X_test_df, test_markers = extract_marker_matrix(test_adata, SELECTED_MARKERS)

# Find common markers between train and test sets
common_markers = sorted(set(train_markers) & set(test_markers))
if len(common_markers) < 2:
    raise ValueError(f"Fewer than 2 common marker genes between datasets. Found: {common_markers}")

print(f"\nCommon markers used for training and testing ({len(common_markers)}): {common_markers}")

# Use only common markers
X_train = X_train_df[common_markers].values
X_test = X_test_df[common_markers].values

# ═════════ 4) Define models ══════════════════════════
pipelines = {
    "LogisticRegression": make_pipe(LogisticRegression(max_iter=1000, random_state=42)),
    "LinearSVC": make_pipe(SVC(kernel="linear", probability=True, random_state=42)),
    "XGBoost": make_pipe(XGBClassifier(
        eval_metric="logloss", random_state=42,
        n_estimators=100, learning_rate=0.05, max_depth=10)),
    "RandomForest": make_pipe(RandomForestClassifier(max_depth=10, random_state=42)),
}

# Set up for curve data collection and plotting
curve_data_roc = []
curve_data_pr = []
colors = {
    "LogisticRegression": "blue",
    "LinearSVC": "red",
    "XGBoost": "green",
    "RandomForest": "purple"
}

# Create figures for plotting
fig_roc, ax_roc = plt.subplots(figsize=(10, 8))
fig_pr, ax_pr = plt.subplots(figsize=(10, 8))

# ═════════ 5) Train and evaluate models ══════════════════
print("\n" + "="*50)
print("Training and evaluating models using selected genes")
print("="*50)

results = []
for name, model in pipelines.items():
    print(f"\n=== {name} ===")
    model.fit(X_train, y_train)

    # 1) Train performance
    p_tr = model.predict(X_train)
    prob_tr = model.predict_proba(X_train)[:, 1]
    tr_acc = accuracy_score(y_train, p_tr)
    tr_f1 = f1_score(y_train, p_tr)
    tr_roc = roc_auc_score(y_train, prob_tr)
    tr_pr = average_precision_score(y_train, prob_tr)
    print(f" TRAIN → Acc={tr_acc:.4f}, F1={tr_f1:.4f}, AUROC={tr_roc:.4f}, AUPRC={tr_pr:.4f}")

    # 2) Test performance
    p_test = model.predict(X_test)
    prob_test = model.predict_proba(X_test)[:, 1]
    test_acc = accuracy_score(y_test, p_test)
    test_f1 = f1_score(y_test, p_test)
    test_roc = roc_auc_score(y_test, prob_test)
    test_pr = average_precision_score(y_test, prob_test)
    print(f" TEST → Acc={test_acc:.4f}, F1={test_f1:.4f}, AUROC={test_roc:.4f}, AUPRC={test_pr:.4f}")
    print("  Confusion Matrix:")
    print(confusion_matrix(y_test, p_test))
    
    results.append({
        "Model": name,
        "Train_Acc": tr_acc, "Train_F1": tr_f1,
        "Train_AUROC": tr_roc, "Train_AUPRC": tr_pr,
        "Test_Acc": test_acc, "Test_F1": test_f1,
        "Test_AUROC": test_roc, "Test_AUPRC": test_pr,
    })
    
    # Calculate ROC curve points
    fpr, tpr, _ = roc_curve(y_test, prob_test)
    roc_df = pd.DataFrame({"model": name, "fpr": fpr, "tpr": tpr})
    curve_data_roc.append(roc_df)
    
    # Calculate PR curve points
    precision, recall, _ = precision_recall_curve(y_test, prob_test)
    pr_df = pd.DataFrame({"model": name, "precision": precision, "recall": recall})
    curve_data_pr.append(pr_df)
    
    # Plot ROC curve
    ax_roc.plot(fpr, tpr, lw=2, color=colors[name], 
             label=f'{name} (area = {test_roc:.3f})')
    
    # Plot PR curve
    ax_pr.plot(recall, precision, lw=2, color=colors[name], 
            label=f'{name} (area = {test_pr:.3f})')

# ═════════ 6) Save results ══════════════════════════
# Combine and save curve data
all_roc_data = pd.concat(curve_data_roc, ignore_index=True)
all_pr_data = pd.concat(curve_data_pr, ignore_index=True)

all_roc_data.to_csv(OUT_DIR / "human2_to_human3_selected_auroc.csv", index=False)
all_pr_data.to_csv(OUT_DIR / "human2_to_human3_selected_auprc.csv", index=False)

# Finalize and save ROC plot
ax_roc.plot([0, 1], [0, 1], 'k--', lw=2)
ax_roc.set_xlim([0.0, 1.0])
ax_roc.set_ylim([0.0, 1.05])
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('Human2 (MLL) → Human3 (Donor): ROC Curves (Selected Genes)')
ax_roc.legend(loc="lower right")
ax_roc.grid(True, linestyle='--', alpha=0.7)
fig_roc.tight_layout()
fig_roc.savefig(OUT_DIR / "human2_to_human3_selected_roc_curves.png", dpi=300, bbox_inches='tight')

# Finalize and save PR plot
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_ylim([0.0, 1.05])
ax_pr.set_xlim([0.0, 1.0])
ax_pr.set_title('Human2 (MLL) → Human3 (Donor): Precision-Recall Curves (Selected Genes)')
ax_pr.legend(loc="lower left")
ax_pr.grid(True, linestyle='--', alpha=0.7)
fig_pr.tight_layout()
fig_pr.savefig(OUT_DIR / "human2_to_human3_selected_pr_curves.png", dpi=300, bbox_inches='tight')

plt.close('all')

# Save summary results
results_df = pd.DataFrame(results)
print("\nFinal results:")
print(results_df)

results_df.to_csv(OUT_DIR / "human2_to_human3_selected_summary_results.csv", index=False)
print(f"\nAll results saved to {OUT_DIR}")

Loading MLL dataset...
Renaming 16 genes using alias dictionary
MLL dataset: 87,171 cells  (♀ 44716  ♂ 42455)

Loading Donor dataset (test data)...
Renaming 16 genes using alias dictionary
Donor dataset: 25,379 cells  (♀ 13508  ♂ 11871)

Extracting 1/15 random subsample from MLL dataset...
Training subsample: 5,811 cells  (♀ 2981  ♂ 2830)
Subsampling ratio: 6.7% of original MLL data

Extracting selected marker genes from training data...
Markers used (9): ['RPS4Y1', 'EIF1AY', 'Xist', 'Ddx3y', 'Uty', 'Kdm5d', 'IFIT3', 'IFIT2', 'RPS4X']

Extracting selected marker genes from test data...
Markers used (9): ['RPS4Y1', 'EIF1AY', 'Xist', 'Ddx3y', 'Uty', 'Kdm5d', 'IFIT3', 'IFIT2', 'RPS4X']

Common markers used for training and testing (9): ['Ddx3y', 'EIF1AY', 'IFIT2', 'IFIT3', 'Kdm5d', 'RPS4X', 'RPS4Y1', 'Uty', 'Xist']

Training and evaluating models using selected genes

=== LogisticRegression ===
 TRAIN → Acc=0.8954, F1=0.8816, AUROC=0.9615, AUPRC=0.9641
 TEST → Acc=0.9764, F1=0.9742, AUROC