In [4]:
# 07_model_interpretation_cleaned.ipynb

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, accuracy_score
import scanpy as sc
import joblib
import json
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier


In [None]:

# Load inputs.
# X_combined is indexed below with a boolean mask derived from `adata`, so row i
# of X_combined is assumed to describe adata.obs_names[i] — TODO confirm upstream.
X_combined = np.load("results/files/X_combined.npy")
# NOTE(review): `y` and the CSV-loaded `importance_df` are not used in this cell
# (labels are re-derived from adata.obs['cell_type']); they are kept for any
# other cell that may rely on them.
# allow_pickle=True can execute arbitrary code on load — only use trusted files.
y = np.load("results/files/y_labels.npy", allow_pickle=True)
importance_df = pd.read_csv("results//files/feature_importance.csv")

# Load AnnData
adata = sc.read_h5ad("results//files/rna_annotated.h5ad")

# Fail early with a clear message if the feature matrix and the AnnData object
# disagree on the number of cells; the positional mask below would otherwise
# raise a less informative IndexError.
if X_combined.shape[0] != adata.n_obs:
    raise ValueError(
        f"X_combined has {X_combined.shape[0]} rows but adata has {adata.n_obs} cells; "
        "cannot align features with cell annotations."
    )

# Filter out rare/noisy cell types
drop_labels = ["Mito-rich/Unknown", "pDCs or DC subtype"]
keep_mask = (~adata.obs['cell_type'].isin(drop_labels)).to_numpy()
adata_filtered = adata[keep_mask].copy()
# Merge the combined monocyte/neutrophil annotation into a single class name.
adata_filtered.obs['cell_type'] = adata_filtered.obs['cell_type'].replace({
    "Monocytes / Neutrophils": "Monocytes"
})

# Encode labels as integer class ids (le.classes_ maps ids back to names)
y_filtered = adata_filtered.obs['cell_type']
le = LabelEncoder()
y_encoded = le.fit_transform(y_filtered)

# Reuse the exact mask that produced adata_filtered. Re-deriving it through
# obs_names membership selects extra rows when barcodes are duplicated, which
# would desynchronise X_filtered from y_encoded.
indices = keep_mask

# Filter X_combined accordingly
X_filtered = X_combined[indices]
# Per-sample weights that balance the classes (rarer classes weigh more)
sample_weights = compute_sample_weight(class_weight="balanced", y=y_encoded)

# Stratified 80/20 split; the weights are split alongside features and labels
split_options = dict(test_size=0.2, stratify=y_encoded, random_state=42)
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
    X_filtered, y_encoded, sample_weights, **split_options
)

# Fit the gradient-boosted classifier on the training split only
xgb_params = {
    "n_estimators": 200,
    "max_depth": 6,
    "random_state": 42,
    "eval_metric": 'mlogloss',
}
model = XGBClassifier(**xgb_params)
model.fit(X_train, y_train, sample_weight=w_train)

# Persist the fitted model together with the encoder needed to decode its output
for artifact_obj, artifact_path in (
    (model, "results//files/updated_model.pkl"),
    (le, "results//files/updated_label_encoder.pkl"),
):
    joblib.dump(artifact_obj, artifact_path)

# Predictions for every retained cell.
# NOTE: X_filtered contains the training split as well, so these columns are for
# visualisation/inspection and are not a held-out performance estimate.
pred_codes = model.predict(X_filtered)
pred_proba = model.predict_proba(X_filtered)

adata_filtered.obs['prediction'] = pred_codes
# Confidence = highest class probability assigned to each cell
adata_filtered.obs['confidence'] = pred_proba.max(axis=1)
adata_filtered.obs['prediction_label'] = le.inverse_transform(pred_codes)
adata_filtered.obs['true_label'] = y_filtered.values
adata_filtered.obs['incorrect'] = adata_filtered.obs['prediction_label'] != adata_filtered.obs['true_label']

# UMAP Visualization
# show=False keeps the figure open: with the default (show), the inline backend
# renders and closes the figure inside sc.pl.umap, so a following plt.gcf()
# returns a brand-new empty figure and the saved PNG is blank.
sc.pl.umap(
    adata_filtered,
    color=['true_label', 'prediction_label', 'confidence', 'incorrect'],
    frameon=False,
    show=False,
)
fig = plt.gcf()
fig.savefig("results//plots/updated_umap_fill_prediction.png", dpi=300)
plt.show()  # display in the notebook only after the file has been written
plt.close(fig)

# Fraction of wrongly predicted cells per true cell type, worst first
error_rate_by_type = adata_filtered.obs.groupby('true_label')['incorrect'].mean()
mis_summary = error_rate_by_type.sort_values(ascending=False)
print("Misclassification rate per cell type:")
print(mis_summary.head(10))

# Write the per-cell prediction details to CSV (the index is written as well)
export_columns = ['cell_type', 'prediction', 'prediction_label', 'confidence', 'incorrect']
misclassification_table = adata_filtered.obs[export_columns]
misclassification_table.to_csv("results//files/updated_misclassifications.csv")

# Distribution of the top predicted probability across all retained cells
conf_fig, conf_ax = plt.subplots(figsize=(8, 4))
sns.histplot(adata_filtered.obs['confidence'], bins=30, kde=True, ax=conf_ax)
conf_ax.set_title("Prediction Confidence Distribution")
conf_ax.set_xlabel("Max Predicted Probability")
conf_fig.tight_layout()
# Save before showing so the written file contains the rendered figure
conf_fig.savefig("results//plots/updated_confidence_distribution.png")
plt.show()

# Evaluate on the held-out split only. Scoring on X_filtered would mix in the
# cells the model was trained on and overstate the figure written to
# updated_test_accuracy.txt.
y_pred_test = model.predict(X_test)  # predict once, reuse for every metric

# Pass the full list of class ids so the matrix and the report stay aligned with
# le.classes_ even if a rare class is missing from the test split.
class_ids = np.arange(len(le.classes_))

# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(y_test, y_pred_test, labels=class_ids)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_).plot(xticks_rotation=90)
plt.title("Updated Confusion Matrix")
plt.tight_layout()
plt.savefig("results//plots/updated_confusion_matrix.png")
plt.show()

# Classification report and accuracy
report = classification_report(
    y_test,
    y_pred_test,
    labels=class_ids,
    target_names=le.classes_,
    output_dict=True,
    zero_division=0,  # classes with no predictions score 0 instead of warning
)
acc = accuracy_score(y_test, y_pred_test)

with open("results//files/updated_classification_report.json", "w") as f:
    json.dump(report, f, indent=2)

with open("results//files/updated_test_accuracy.txt", "w") as f:
    f.write(f"{acc:.4f}")


In [None]:
# Load feature names from CSV (no header row; the first column holds the names)
feature_names_df = pd.read_csv("results/files/feature_names.csv", header=None)
features = feature_names_df.iloc[:, 0].values

# `importances` was never defined in this notebook, so the cell failed with a
# NameError on a fresh kernel. Take the importances from the model trained above.
importances = model.feature_importances_

# The names are matched to importances by position, so the lengths must agree.
if len(features) != len(importances):
    raise ValueError(
        f"feature_names.csv lists {len(features)} names but the model reports "
        f"{len(importances)} feature importances."
    )

importance_df = pd.DataFrame({
    "Feature": features,
    "Importance": importances
})
# Modality is the text after the last underscore of the feature name;
# the filters below expect it to be 'RNA' or 'ATAC'.
importance_df['Modality'] = importance_df['Feature'].str.rsplit('_', n=1).str[-1]

# Top features by modality
top_n = 20
top_rna = importance_df[importance_df['Modality'] == 'RNA'].sort_values('Importance', ascending=False).head(top_n)
top_atac = importance_df[importance_df['Modality'] == 'ATAC'].sort_values('Importance', ascending=False).head(top_n)

# Save top features
top_rna.to_csv("results//files/updated_top_rna_genes.csv", index=False)
top_atac.to_csv("results//files/updated_top_atac_genes.csv", index=False)

# Plot feature importances: one horizontal bar chart per modality
for modality_name, top_features, plot_path in (
    ("RNA", top_rna, "results//plots/updated_top_rna_importances.png"),
    ("ATAC", top_atac, "results//plots/updated_top_atac_importances.png"),
):
    imp_fig, imp_ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x='Importance', y='Feature', data=top_features, ax=imp_ax)
    imp_ax.set_title(f"Top {top_n} {modality_name} Feature Importances")
    imp_fig.tight_layout()
    imp_fig.savefig(plot_path)
    plt.show()
