# CyTOF Label Transfer: Advanced Workflow
## Feature Evaluation, Selection, Custom Hyperparameters, and Prediction

This notebook walks through the complete CyTOF label-transfer workflow, covering:
1. **Feature Evaluation**: Compute and visualize feature importance before training
2. **Feature Selection**: Select features by group or importance threshold
3. **Custom Hyperparameters**: Use custom XGBoost hyperparameter distributions
4. **Model Training**: Train with cross-validated hyperparameter search
5. **Evaluation**: Generate QC plots and metrics
6. **Prediction**: Apply trained model to target timepoint

**Prerequisites**:
- conda environment from `environment.yml`
- `.h5ad` file with timepoints and trusted labels for timepoints 1–4

## 1. Import Required Libraries

In [2]:
# Core imports
import os
import json
from pathlib import Path

# Data handling
import numpy as np
import pandas as pd
import anndata as ad

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# CyTOF Label Transfer
from cytof_label_transfer import (
    load_anndata,
    split_timepoints,
    extract_xy,
    compute_feature_importance,
    create_feature_groups,
    plot_feature_importance,
    select_features_by_groups,
    select_features_by_importance,
    select_features_interactive_report,
    train_classifier,
    load_hyperparameters_from_json,
    predict_timepoint,
)
from cytof_label_transfer.data_utils import extract_x_target
from cytof_label_transfer.qc import evaluate_and_plot_cv
from cytof_label_transfer.model import TrainedModelBundle

# Plotting configuration
%matplotlib inline
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100

print('✓ All libraries imported successfully')
print(f'Working directory: {os.getcwd()}')

ImportError: cannot import name 'extract_xy' from 'cytof_label_transfer' (/home/md-adnan-karim/Documents/git_repo/cytofLabelTransfer/cytofPredictionModel/cytof_label_transfer/__init__.py)

## 2. Load and Prepare Data

In [None]:
# Run configuration: input file, AnnData keys, output locations, timepoints.
INPUT_H5AD = 'path/to/your_data.h5ad'  # UPDATE THIS PATH
TIMEPOINT_COL = 'timepoint'             # obs column holding the timepoint of each cell
LABEL_COL = 'celltype'                  # obs column holding the cell type label
OBSM_KEY = 'X_scVI_200_epoch'           # Optional .obsm latent-space key (None to skip)
LAYER = None                            # Optional layer name (None = use .X)

# Where results are written
OUTPUT_DIR = Path('results')
FEATURE_EVAL_DIR = OUTPUT_DIR / 'feature_evaluation'
MODEL_DIR = OUTPUT_DIR / 'trained_model'

# Timepoints with trusted labels (training) and the one to label (target)
TRAIN_TIMEPOINTS = [1, 2, 3, 4]  # Trusted timepoints
TARGET_TIMEPOINT = 5              # Target timepoint for prediction

# Ensure every output directory exists; safe to call repeatedly
for _out_dir in (OUTPUT_DIR, FEATURE_EVAL_DIR, MODEL_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

print('✓ Configuration set')
print(f'  Input file: {INPUT_H5AD}')
print(f'  Output directory: {OUTPUT_DIR}')
print(f'  Training timepoints: {TRAIN_TIMEPOINTS}')
print(f'  Target timepoint: {TARGET_TIMEPOINT}')

In [None]:
# Read the AnnData object from disk and print a short overview of it.
print(f'Loading data from {INPUT_H5AD}...')
adata = load_anndata(INPUT_H5AD)

print('\n✓ Data loaded successfully')
print(f'  Shape: {adata.n_obs} cells × {adata.n_vars} genes')
timepoints_present = sorted(adata.obs[TIMEPOINT_COL].unique())
print(f'  Timepoints: {timepoints_present}')
n_celltypes = adata.obs[LABEL_COL].nunique()
print(f'  Cell types: {n_celltypes} types')
if OBSM_KEY:
    n_latent_dims = adata.obsm[OBSM_KEY].shape[1]
    print(f'  Latent space ({OBSM_KEY}): {n_latent_dims} dimensions')

In [None]:
# Partition cells into the trusted training timepoints and the target timepoint.
print('Splitting data...')
split_kwargs = {
    'time_col': TIMEPOINT_COL,
    'train_timepoints': TRAIN_TIMEPOINTS,
    'target_timepoint': TARGET_TIMEPOINT,
}
adata_train, adata_target = split_timepoints(adata, **split_kwargs)

print('\n✓ Data split successfully')
print(f'  Training set: {adata_train.n_obs} cells from timepoints {TRAIN_TIMEPOINTS}')
print(f'  Target set: {adata_target.n_obs} cells from timepoint {TARGET_TIMEPOINT}')

# How many training cells carry each label
print('\nCell type distribution in training set:')
train_label_counts = adata_train.obs[LABEL_COL].value_counts()
print(train_label_counts)

In [None]:
# Build the training matrix, labels and feature names from the training cells.
print('Extracting features...')
X_train, y_train, feature_names = extract_xy(
    adata_train,
    label_col=LABEL_COL,
    use_layer=LAYER,
    use_obsm_key=OBSM_KEY,
)

print('\n✓ Features extracted')
print(f'  Feature matrix shape: {X_train.shape}')
print(f'  Total features: {len(feature_names)}')
if OBSM_KEY:
    # Whatever is not a latent dimension is counted as a marker feature
    n_latent = adata.obsm[OBSM_KEY].shape[1]
    n_markers = len(feature_names) - n_latent
    print(f'    - Markers: {n_markers}')
    print(f'    - Latent ({OBSM_KEY}): {n_latent}')
class_counts = dict(pd.Series(y_train).value_counts())
print(f'  Class distribution: {class_counts}')

## 3. Feature Importance Evaluation

In [None]:
# Score every feature with a random-forest importance estimate and list the top 10.
print('Computing feature importance (Random Forest)...')
importances, _ = compute_feature_importance(
    X_train, y_train, feature_names,
    method='random_forest',
    n_estimators=100,
)

print('✓ Feature importance computed')
print('\nTop 10 most important features:')
top_indices = np.argsort(importances)[::-1][:10]
for rank, idx in enumerate(top_indices, start=1):
    feat_name, feat_score = feature_names[idx], importances[idx]
    print(f'  {rank:2d}. {feat_name:30s} importance={feat_score:.4f}')

In [None]:
# Plot the 30 highest-ranked features and save the figure to FEATURE_EVAL_DIR.
top30_plot_path = FEATURE_EVAL_DIR / 'feature_importance_top30.png'
print('Plotting feature importance...')
plot_feature_importance(
    importances,
    feature_names,
    top_n=30,
    output_path=top30_plot_path,
    figsize=(12, 8),
)
plt.show()
print(f'✓ Plot saved to {top30_plot_path}')

In [None]:
# Group the features, write the importance report CSV, and show its first rows.
print('Generating feature importance report...')
feature_groups = (
    create_feature_groups(feature_names, obsm_key=OBSM_KEY)
    if OBSM_KEY
    else {'all_markers': list(range(len(feature_names)))}
)

select_features_interactive_report(
    importances,
    feature_names,
    feature_groups,
    output_dir=FEATURE_EVAL_DIR,
)
report_path = FEATURE_EVAL_DIR / 'feature_importance_report.csv'
print(f'✓ Report saved to {report_path}')

# Read the report back and display the top of it
report_df = pd.read_csv(report_path)
print('\nFeature Importance Report (top 15):')
print(report_df.head(15).to_string())

## 4. Feature Selection (Optional)

In [None]:
# Choose which features the model is trained on.
# Set FEATURE_SELECTION_MODE to one of:
#   'all'            - keep every extracted feature
#   'markers_only'   - keep only the 'markers' group (requires OBSM_KEY)
#   'latent_only'    - keep only the 'latent' group (requires OBSM_KEY)
#   'top_percentile' - keep features above the 90th importance percentile
#   'manual'         - keep an explicit index list (example: top 20 by importance)
FEATURE_SELECTION_MODE = 'all'

_VALID_SELECTION_MODES = ('all', 'markers_only', 'latent_only', 'top_percentile', 'manual')
if FEATURE_SELECTION_MODE not in _VALID_SELECTION_MODES:
    # Fail loudly: a typo here used to fall through silently to "all features".
    raise ValueError(
        f'Unknown FEATURE_SELECTION_MODE {FEATURE_SELECTION_MODE!r}; '
        f'expected one of {_VALID_SELECTION_MODES}'
    )

selected_feature_indices = None
selected_feature_names = None
selection_mode = FEATURE_SELECTION_MODE

# Group-based modes rely on the marker/latent groups, which only exist with OBSM_KEY.
if selection_mode in ('markers_only', 'latent_only') and not OBSM_KEY:
    print(f"⚠ '{selection_mode}' requires OBSM_KEY to be set; falling back to all features")
    selection_mode = 'all'

# `importances` and `feature_groups` were built on the full feature set. If this
# cell is re-run after features were re-extracted with a selection,
# `feature_names` is shorter and the indices would no longer line up.
if selection_mode != 'all' and len(importances) != len(feature_names):
    raise RuntimeError(
        'importances and feature_names have different lengths; re-run the '
        'feature extraction and feature importance cells before selecting again'
    )

if selection_mode == 'markers_only':
    print('Selecting markers only...')
    selected_feature_indices, selected_feature_names = select_features_by_groups(
        feature_names,
        feature_groups,
        ['markers'],
    )
    print(f'✓ Selected {len(selected_feature_indices)} marker features')

elif selection_mode == 'latent_only':
    print('Selecting latent features only...')
    selected_feature_indices, selected_feature_names = select_features_by_groups(
        feature_names,
        feature_groups,
        ['latent'],
    )
    print(f'✓ Selected {len(selected_feature_indices)} latent features')

elif selection_mode == 'top_percentile':
    print('Selecting features above 90th percentile...')
    selected_feature_indices, selected_feature_names = select_features_by_importance(
        importances,
        feature_names,
        percentile=90,
    )
    print(f'✓ Selected {len(selected_feature_indices)} features')

elif selection_mode == 'manual':
    # Manually specify feature indices (0-based, relative to the full feature set)
    print('Using manually selected features...')
    # Example: top 20 features by importance
    top_20_indices = np.argsort(importances)[::-1][:20]
    selected_feature_indices = top_20_indices
    selected_feature_names = [feature_names[i] for i in selected_feature_indices]
    print(f'✓ Selected {len(selected_feature_indices)} features')

else:
    print('Using all features')

if selected_feature_indices is not None:
    print(f'\nSelected features: {selected_feature_names}')

In [None]:
# Re-extract training features restricted to the selection made above.
if selected_feature_indices is not None:
    # `importances` was computed on the full (pre-selection) feature set, so its
    # length is the original feature count even after feature_names is replaced
    # below. The old expression reduced to len(selected_feature_indices).
    n_features_original = len(importances)
    print('Re-extracting features with selection...')
    X_train, y_train, feature_names = extract_xy(
        adata_train,
        label_col=LABEL_COL,
        use_layer=LAYER,
        use_obsm_key=OBSM_KEY,
        selected_feature_indices=selected_feature_indices,
    )
    print('✓ Features re-extracted')
    print(f'  Original features: {n_features_original}')
    print(f'  Selected features: {len(feature_names)}')

## 5. Custom Hyperparameters (Optional)

In [None]:
# Optionally replace the default XGBoost search space with one loaded from JSON.
USE_CUSTOM_HYPERPARAMS = False
CUSTOM_HYPERPARAMS_FILE = 'custom_hyperparams.json'  # Path to JSON file

param_distributions = None  # None lets train_classifier use its defaults

if not USE_CUSTOM_HYPERPARAMS:
    print('Using default hyperparameter distributions')
else:
    print(f'Loading custom hyperparameters from {CUSTOM_HYPERPARAMS_FILE}...')
    param_distributions = load_hyperparameters_from_json(CUSTOM_HYPERPARAMS_FILE)
    print(f'✓ Loaded {len(param_distributions)} hyperparameter settings')
    print('\nHyperparameter search space:')
    for param_name, param_values in param_distributions.items():
        print(f'  {param_name}: {param_values}')

## 6. Train the Model

In [None]:
# Cross-validation and hyperparameter-search settings.
CV_FOLDS = 5
CV_ITERATIONS = 30  # Number of random hyperparameter configurations to try
USE_GPU = False     # Set to True if you have GPU and XGBoost GPU support

rule = '=' * 60
print(rule)
print('STARTING MODEL TRAINING')
print(rule)
n_train_cells = X_train.shape[0]
n_train_features = X_train.shape[1]
print(f'Training set: {n_train_cells} cells × {n_train_features} features')
n_classes = len(np.unique(y_train))
print(f'Number of classes: {n_classes}')
print(f'CV folds: {CV_FOLDS}')
print(f'Hyperparameter iterations: {CV_ITERATIONS}')
print(f'Using GPU: {USE_GPU}')
print(rule)

In [None]:
# Fit the classifier with the cross-validated hyperparameter search configured above.
training_kwargs = dict(
    feature_names=feature_names,
    n_splits=CV_FOLDS,
    n_iter=CV_ITERATIONS,
    param_distributions=param_distributions,
    output_dir=MODEL_DIR,
    use_gpu=USE_GPU,
)
bundle = train_classifier(X_train, y_train, **training_kwargs)

print('\n✓ Model training completed')
print(f'  Best CV F1 score: {bundle.cv_best_score:.4f}')
print(f'  Number of classes: {len(bundle.label_names)}')
print(f'  Number of features: {len(bundle.feature_names)}')

In [None]:
# Show the metrics JSON that training wrote into MODEL_DIR (if it exists).
metrics_file = MODEL_DIR / 'training_metrics.json'
if metrics_file.exists():
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(metrics_file, encoding='utf-8') as f:
        metrics = json.load(f)

    print('\nTraining Metrics:')
    print(f'  Cross-validated F1 (macro): {metrics["cv_best_score"]:.4f}')
    print(f'  Training F1 (macro): {metrics["train_f1_macro"]:.4f}')
    print(f'  Training Accuracy: {metrics["train_accuracy"]:.4f}')
    print(f'  Number of training samples: {metrics["n_samples"]}')
    print(f'  Number of features used: {metrics["n_features"]}')
    print(f'\nBest hyperparameters:')
    for param, value in metrics['best_params'].items():
        print(f'  {param}: {value}')
else:
    # Previously this cell produced no output at all, which is easy to miss.
    print(f'Note: {metrics_file} not found; no training metrics to display')

## 7. Evaluate the Model

In [None]:
# Produce cross-validated QC plots and metrics for the trained estimator.
qc_dir = MODEL_DIR / 'qc'
print('Generating QC plots and evaluation metrics...')

evaluate_and_plot_cv(
    estimator=bundle.estimator,
    X=X_train,
    y=y_train,
    class_names=bundle.label_names,
    output_dir=qc_dir,
    n_splits=CV_FOLDS,
    label_encoder=bundle.label_encoder,
)

print(f'✓ QC plots saved to {qc_dir}')

In [None]:
# Render the saved cross-validated confusion-matrix image inline, if it was written.
from matplotlib.image import imread

cm_path = qc_dir / 'cv_confusion_matrix.png'
if cm_path.exists():
    fig, ax = plt.subplots(figsize=(10, 8))
    cm_image = imread(cm_path)
    ax.imshow(cm_image)
    ax.axis('off')
    ax.set_title('Cross-validated Confusion Matrix')
    fig.tight_layout()
    plt.show()
    print('Confusion matrix displayed')

In [None]:
# Render the saved per-class F1 plot inline, if it was written.
# Import imread here as well so this cell does not depend on the
# confusion-matrix cell having been run first (previously a NameError).
from matplotlib.image import imread

f1_path = qc_dir / 'cv_per_class_f1.png'
if f1_path.exists():
    fig, ax = plt.subplots(figsize=(12, 5))
    img = imread(f1_path)
    ax.imshow(img)
    ax.axis('off')
    plt.title('Per-class F1 Scores (Cross-validation)')
    plt.tight_layout()
    plt.show()
    print('Per-class F1 plot displayed')
else:
    print(f'Note: per-class F1 plot not found at {f1_path}')

## 8. Make Predictions on Target Timepoint

In [None]:
# Build the target-timepoint feature matrix using the same layer, latent key and
# feature selection that were passed when extracting the training features.
print(f'Extracting features for target timepoint {TARGET_TIMEPOINT}...')

target_extract_kwargs = {
    'use_layer': LAYER,
    'use_obsm_key': OBSM_KEY,
    'selected_feature_indices': selected_feature_indices,
}
X_target, _ = extract_x_target(adata_target, **target_extract_kwargs)

print('✓ Target features extracted')
print(f'  Shape: {X_target.shape}')
n_target_cells = X_target.shape[0]
print(f'  Number of cells: {n_target_cells}')

In [None]:
# Predict labels and class probabilities for every target cell.
print(f'Making predictions on {adata_target.n_obs} cells...')

y_pred, y_proba = predict_timepoint(bundle, X_target)

print('✓ Predictions completed')
print(f'  Number of predictions: {len(y_pred)}')
print(f'  Predicted classes: {np.unique(y_pred)}')
print('\nPrediction distribution:')
prediction_counts = pd.Series(y_pred).value_counts().sort_index()
print(prediction_counts)

In [None]:
# Per-cell confidence = highest class probability; summarise and plot it.
if y_proba is not None:
    max_confidence = y_proba.max(axis=1)
    mean_conf = max_confidence.mean()
    print('\nPrediction Confidence Statistics:')
    summary_stats = (
        ('Mean', mean_conf),
        ('Median', np.median(max_confidence)),
        ('Min', max_confidence.min()),
        ('Max', max_confidence.max()),
        ('Std', max_confidence.std()),
    )
    for stat_label, stat_value in summary_stats:
        print(f'  {stat_label}: {stat_value:.4f}')

    # Histogram of confidences with the mean marked
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(max_confidence, bins=30, edgecolor='black', alpha=0.7)
    ax.axvline(mean_conf, color='red', linestyle='--', label=f'Mean: {mean_conf:.3f}')
    ax.set_xlabel('Prediction Confidence')
    ax.set_ylabel('Number of Cells')
    ax.set_title('Distribution of Prediction Confidence Scores')
    ax.legend()
    plt.tight_layout()
    plt.show()

## 9. Write Predictions Back to AnnData

In [None]:
# Store predictions for the target cells in the full AnnData object; every other
# cell keeps the empty (NaN) value the new columns are created with.
print('Writing predictions to AnnData object...')

all_cells = adata.obs_names
target_cells = adata_target.obs_names

adata.obs['celltype_predicted'] = pd.Series(index=all_cells, dtype='object')
adata.obs['prediction_confidence'] = pd.Series(index=all_cells, dtype='float')

adata.obs.loc[target_cells, 'celltype_predicted'] = y_pred
if y_proba is not None:
    adata.obs.loc[target_cells, 'prediction_confidence'] = max_confidence

print('✓ Predictions added to adata.obs')
print('\nColumns in adata.obs:')
print(adata.obs[['celltype_predicted', 'prediction_confidence']].head(10))

In [None]:
# Persist the AnnData, including the new prediction columns, to disk.
output_h5ad = OUTPUT_DIR / 'data_with_predictions.h5ad'

print(f'Saving updated AnnData to {output_h5ad}...')
adata.write_h5ad(output_h5ad)

size_mb = output_h5ad.stat().st_size / (1024 * 1024)
print('✓ Saved successfully')
print(f'  File size: {size_mb:.1f} MB')

## 10. Visualize Results

In [None]:
# Compare predictions with any labels already present for the target timepoint.
# LABEL_COL is always a column of adata_target (it is a subset of adata), so
# checking column presence alone never reached the "no labels" branch; test for
# actual non-missing values instead.
original_labels = (
    adata_target.obs[LABEL_COL] if LABEL_COL in adata_target.obs.columns else None
)
if original_labels is not None and original_labels.notna().any():
    # Create comparison dataframe (y_pred is aligned positionally with the target cells)
    comparison_df = pd.DataFrame({
        'original': original_labels,
        'predicted': y_pred,
        'confidence': max_confidence if y_proba is not None else 1.0,
    })

    print('Sample Predictions vs Original Labels:')
    print(comparison_df.head(20).to_string())

    # Only score cells that actually have an original label; missing labels
    # would otherwise be counted as mismatches and deflate the agreement.
    labeled_mask = original_labels.notna().to_numpy()
    original_arr = np.asarray(original_labels)[labeled_mask]
    predicted_arr = np.asarray(y_pred)[labeled_mask]
    matches = int((original_arr == predicted_arr).sum())
    total = int(labeled_mask.sum())  # > 0, guaranteed by the any() check above
    accuracy = matches / total
    print(f'\nComparison with original labels:')
    print(f'  Matches: {matches}/{total} ({accuracy*100:.1f}%)')
    n_unlabeled = len(y_pred) - total
    if n_unlabeled:
        print(f'  ({n_unlabeled} cells without an original label were excluded)')
else:
    print('Note: No original labels for target timepoint to compare')

In [None]:
# Summary statistics for the whole run.
# `importances` was computed on the full, pre-selection feature set, so its
# length is the initial feature count. The previous expression added the number
# of selected features to itself instead.
n_initial_features = len(importances)

print('\n' + '='*60)
print('WORKFLOW SUMMARY')
print('='*60)
print(f'\nInput Data:')
print(f'  Total cells: {adata.n_obs}')
print(f'  Training set: {adata_train.n_obs} cells')
print(f'  Target set: {adata_target.n_obs} cells')
print(f'\nFeatures:')
print(f'  Initial features: {n_initial_features}')
print(f'  Features used: {len(feature_names)}')
print(f'\nModel:')
print(f'  Algorithm: XGBoost')
print(f'  Classes: {len(bundle.label_names)}')
print(f'  CV F1 Score: {bundle.cv_best_score:.4f}')
print(f'\nPredictions:')
print(f'  Predictions made: {len(y_pred)}')
if y_proba is not None:
    print(f'  Mean confidence: {max_confidence.mean():.4f}')
print(f'\nOutput Files:')
print(f'  Model: {MODEL_DIR}')
print(f'  QC plots: {MODEL_DIR / "qc"}')
print(f'  Predictions: {output_h5ad}')
print('='*60)

## Next Steps

1. **Review the QC plots** in `results/trained_model/qc/` (i.e. `MODEL_DIR / 'qc'`) to assess model quality
2. **Check the predictions** and their confidence scores
3. **Validate results** by comparing with original labels if available
4. **Adjust parameters** if needed:
   - Try different feature selections
   - Experiment with custom hyperparameters
   - Increase/decrease CV iterations
5. **Use the trained model** to predict on new data

For more information, see:
- [ADVANCED_USAGE.md](ADVANCED_USAGE.md) for detailed feature documentation
- [PRACTICAL_EXAMPLES.md](PRACTICAL_EXAMPLES.md) for more code examples
- [README.md](README.md) for installation and basic usage