In [2]:
# ML Classifier: Predict Cell Types from RNA+ATAC

import json
from collections import Counter

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.base import clone
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold, RandomizedSearchCV
from sklearn.preprocessing import LabelEncoder
from xgboost import XGBClassifier



In [None]:
# Load data: every column except `cell_type` is used as a model feature.
file_path = "results/files/ag_X_combined_with_labels.csv"
df = pd.read_csv(file_path)
# NOTE(review): if this CSV was written with its index, an 'Unnamed: 0' column would be
# treated as a feature here — confirm how the file was saved.
X = df.drop(columns=['cell_type'])
y = df['cell_type']

# sklearn's LabelEncoder encodes NaN in an object column as its own class, so unlabeled
# rows would silently become an extra "cell type". Fail loudly instead.
n_missing_labels = int(y.isna().sum())
if n_missing_labels:
    raise ValueError(f"{n_missing_labels} rows in {file_path} have a missing 'cell_type' label")

# Encode labels: classes_ is sorted and class classes_[i] is encoded as integer i.
le = LabelEncoder()
y_encoded = le.fit_transform(y)
# Built from enumerate so the mapping prints plain ints rather than NumPy scalar reprs.
print("Label mapping:", {label: code for code, label in enumerate(le.classes_)})

# Save label encoder (needed to decode predictions from the saved models)
joblib.dump(le, "results/files/ag_label_encoder.pkl")


In [None]:
# Hold out a stratified 20% test split; the hyperparameter search below only sees the training part.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=0.2,
    stratify=y_encoded,
    random_state=42,
)
print("Train class distribution:", Counter(y_train))
print("Test class distribution:", Counter(y_test))

# Base estimator: multi-class log-loss as the eval metric, fixed seed for reproducibility.
xgb_clf = XGBClassifier(eval_metric='mlogloss', random_state=42)

# Candidate values; RandomizedSearchCV samples 20 of the 3*3*3*2*2 = 108 combinations.
param_dist = {
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'learning_rate': [0.01, 0.1, 0.2],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0]
}

# An unshuffled StratifiedKFold(n_splits=3) is exactly what cv=3 resolves to for a classifier.
random_search = RandomizedSearchCV(
    xgb_clf,
    param_dist,
    n_iter=20,
    scoring='accuracy',
    cv=StratifiedKFold(n_splits=3),
    n_jobs=-1,
    verbose=1,
    random_state=42,
)
random_search.fit(X_train, y_train)

print("Best Parameters:", random_search.best_params_)
print("Best CV Accuracy:", random_search.best_score_)

In [None]:
# Evaluate the tuned model (refit by the search on the training split only) on the held-out test split.
best_model = random_search.best_estimator_
y_pred = best_model.predict(X_test)
y_pred_labels = le.inverse_transform(y_pred)  # predicted cell-type names

# Pin the report to every encoded class so `target_names` always lines up. Without `labels`,
# sklearn uses only classes present in y_test/y_pred and raises ValueError when a rare class is
# missing from both.
report_labels = np.arange(len(le.classes_))
report_dict = classification_report(
    y_test, y_pred, labels=report_labels, target_names=le.classes_, output_dict=True
)
print("\nClassification Report:")
print(classification_report(y_test, y_pred, labels=report_labels, target_names=le.classes_))

# Save model and results (this is the train-split model, before any full-data refit)
joblib.dump(best_model, "results/files/ag_xgb_best_model.pkl")
np.save("results/files/ag_y_pred.npy", y_pred)
np.save("results/files/ag_y_test.npy", y_test)

with open("results/files/ag_classification_report.json", "w") as f:
    json.dump(report_dict, f, indent=2)


In [None]:
# Train on full data for annotation.
# clone() returns an UNFITTED copy with the tuned hyperparameters. Refitting
# random_search.best_estimator_ directly would mutate `best_model` (same object), so later
# test-set probability / feature-importance cells would use a model that has seen X_test.
best_model_full = clone(random_search.best_estimator_)
best_model_full.fit(X, y_encoded)

joblib.dump(best_model_full, "results/files/ag_xgb_model_full_data.pkl")

# In-sample predictions on the same data the model was fit on — optimistic, not an accuracy estimate.
y_pred_full = best_model_full.predict(X)
np.save("results/files/ag_y_pred_full.npy", y_pred_full)

y_proba_full = best_model_full.predict_proba(X)
np.save("results/files/ag_y_proba_full.npy", y_proba_full)

# Confusion Matrix of the held-out test predictions (y_pred from the train-split model), row-normalized.
# `labels` keeps the matrix shape equal to len(display_labels) even if a class never appears.
cm = confusion_matrix(y_test, y_pred, labels=np.arange(len(le.classes_)), normalize='true')
# Draw into an explicit Axes: ConfusionMatrixDisplay.plot() without `ax` opens its own figure,
# leaving a separately created 15x8 figure empty and ignoring the requested size.
fig, ax = plt.subplots(figsize=(15, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_)
disp.plot(ax=ax, xticks_rotation=90, cmap='Blues')
ax.set_title("Normalized Confusion Matrix")
fig.tight_layout()
fig.savefig("results/plots/ag_confusion_matrix.png")
plt.show()


In [None]:
# Class Probability Distributions: one filled KDE per class of the predicted probability
# for that class across the test split. Expects `best_model` to be fit on the training split only.
y_proba = best_model.predict_proba(X_test)

fig, ax = plt.subplots(figsize=(12, 6))
for i, class_name in enumerate(le.classes_):
    sns.kdeplot(y_proba[:, i], label=class_name, fill=True, alpha=0.4, ax=ax)
ax.set_title("Class Probability Distributions")
ax.set_xlabel("Predicted Probability")
ax.legend()
fig.tight_layout()
fig.savefig("results/plots/ag_class_probability_distribution.png")
plt.show()

In [10]:
# Feature Importance: rank every input column by the fitted model's feature_importances_
# (the importance type is whatever the estimator's importance_type setting selects).
importances = best_model.feature_importances_
feature_names = X.columns

# NOTE(review): modality is inferred purely from column names — any feature without '_RNA'
# in its name is labelled ATAC; confirm this matches how the combined matrix was built.
importance_df = (
    pd.DataFrame({'Feature': feature_names, 'Importance': importances})
    .assign(Modality=lambda frame: ['RNA' if '_RNA' in name else 'ATAC' for name in frame['Feature']])
    .sort_values(by='Importance', ascending=False)
)

importance_df.to_csv("results/files/ag_feature_importance.csv", index=False)


In [None]:
# Top 20 Features: importance_df is already sorted descending, so head(20) is the top 20.
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(data=importance_df.head(20), x='Importance', y='Feature', ax=ax)
ax.set_title("Top 20 Feature Importances")
fig.tight_layout()
fig.savefig("results/plots/ag_feature_importance_top20.png")
plt.show()

# Modality-level Feature Importance: per-feature importances summed within each modality.
modality_importance = (
    importance_df
    .groupby('Modality')['Importance']
    .sum()
    .reset_index()
)
fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(data=modality_importance, x='Modality', y='Importance', ax=ax)
ax.set_title("Total Feature Importance by Modality")
fig.tight_layout()
fig.savefig("results/plots/ag_modality_feature_importance.png")
plt.show()