# 03. Phase Classifier
## Disease Stage Prediction Model (IT/IA/AR/AC)

**Objectives:**
- Classify disease stage from IT signature and pathway features
- Validate performance with cross-validation
- Analyze feature importance

**Clinical utility:**
- Predict the disease stage of new patients
- Potential to predict treatment response

---

## 1. Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import sys
PROJECT_ROOT = '/content/drive/MyDrive/ITLAS'
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.ensemble import RandomForestClassifier

import warnings
warnings.filterwarnings('ignore')

print("✓ Setup complete")

In [None]:
# Install XGBoost if needed (optional)
try:
    from xgboost import XGBClassifier
    HAS_XGB = True
    print("✓ XGBoost available")
except ImportError:
    !pip install xgboost -q
    from xgboost import XGBClassifier
    HAS_XGB = True
    print("✓ XGBoost installed")

## 2. Load Data

In [None]:
# Load the data that includes pathway features
data_path = f"{PROJECT_ROOT}/data/processed/GSE182159_with_pathways.h5ad"

try:
    adata = sc.read_h5ad(data_path)
    print(f"✓ Loaded: {adata.shape}")
except FileNotFoundError:
    # Fallback: file with IT scores only
    data_path = f"{PROJECT_ROOT}/data/processed/GSE182159_with_IT_scores.h5ad"
    try:
        adata = sc.read_h5ad(data_path)
        print(f"✓ Loaded (IT scores only): {adata.shape}")
        print("⚠ Run 02_FM_GSEA.ipynb for pathway features")
    except FileNotFoundError:
        # Stop here: every later cell needs `adata`
        raise FileNotFoundError("Run 01 and 02 notebooks first!")
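
The nested try/except fallback above can also be written as a search over an ordered list of candidate files. The following is a minimal sketch of that idea, not the notebook's own loader: `first_existing` is a hypothetical helper, and the temporary files only stand in for the real `.h5ad` files.

```python
from pathlib import Path
import tempfile


def first_existing(candidates):
    """Hypothetical helper: return the first candidate path that exists."""
    for candidate in candidates:
        path = Path(candidate)
        if path.exists():
            return path
    raise FileNotFoundError(f"None of these files exist: {[str(c) for c in candidates]}")


# Demo with temporary files standing in for the processed .h5ad files.
with tempfile.TemporaryDirectory() as tmp:
    preferred = Path(tmp) / "with_pathways.h5ad"   # never created, so it is skipped
    fallback = Path(tmp) / "with_IT_scores.h5ad"
    fallback.touch()
    chosen = first_existing([preferred, fallback])
    chosen_name = chosen.name

print(chosen_name)
```

In the notebook, the returned path would then be passed to `sc.read_h5ad`.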

In [None]:
# Check the Stage distribution
print("\nStage distribution:")
print(adata.obs['Stage'].value_counts())

# Check the feature columns
it_cols = [c for c in adata.obs.columns if c.startswith('IT_')]
pw_cols = [c for c in adata.obs.columns if c.startswith('PW_')]
print(f"\nIT features: {len(it_cols)}")
print(f"Pathway features: {len(pw_cols)}")

## 3. Feature Extraction

In [None]:
# Build the feature matrix
feature_cols = it_cols + pw_cols

# Exclude IT_like because it is categorical
feature_cols = [c for c in feature_cols if c != 'IT_like']

print(f"Total features: {len(feature_cols)}")
print(f"Features: {feature_cols}")
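
Excluding `IT_like` by name works as long as it is the only non-numeric column. A dtype-based filter may be more robust. The sketch below uses a toy `obs` table in place of `adata.obs`, and every column name in it except `IT_like` and `Stage` is made up for illustration.

```python
import pandas as pd

# Toy stand-in for adata.obs; column names other than IT_like and Stage are hypothetical.
obs = pd.DataFrame({
    "IT_score_a": [0.1, 0.4, 0.3],
    "IT_like": pd.Categorical(["yes", "no", "yes"]),
    "PW_example": [1.2, 0.8, 1.0],
    "Stage": ["IT", "IA", "IT"],
})

# Candidate columns by prefix, as in the cell above.
candidates = [c for c in obs.columns if c.startswith(("IT_", "PW_"))]

# Keep only numeric candidates; categorical columns such as IT_like drop out.
numeric_features = obs[candidates].select_dtypes(include="number").columns.tolist()
print(numeric_features)
```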

In [None]:
# Prepare X and y (cast to float so the NaN check below works)
X = adata.obs[feature_cols].to_numpy(dtype=float)
y = adata.obs['Stage'].values

# Label encoding
le = LabelEncoder()
y_encoded = le.fit_transform(y)

print(f"X shape: {X.shape}")
print(f"Classes: {le.classes_}")

# Handle missing values
if np.isnan(X).any():
    print("⚠ NaN values detected, filling with 0")
    X = np.nan_to_num(X, nan=0.0)

In [None]:
# Train/Test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
)

# Scaling
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print(f"Train: {X_train.shape}, Test: {X_test.shape}")

## 4. Model Training & Evaluation

In [None]:
# Compare multiple models
from sklearn.linear_model import LogisticRegression

models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'Random Forest': RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42),
}

if HAS_XGB:
    models['XGBoost'] = XGBClassifier(
        n_estimators=100, max_depth=5, learning_rate=0.1,
        random_state=42, eval_metric='mlogloss'
    )

results = []

for name, model in models.items():
    print(f"\n{'='*50}")
    print(f"Training: {name}")
    print('='*50)
    
    # Train
    model.fit(X_train_scaled, y_train)
    
    # Predict
    y_pred = model.predict(X_test_scaled)
    
    # Metrics
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')
    
    results.append({
        'Model': name,
        'Accuracy': acc,
        'F1_weighted': f1
    })
    
    print(f"Accuracy: {acc:.4f}")
    print(f"F1 (weighted): {f1:.4f}")
    print(f"\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=le.classes_))

# Results summary
results_df = pd.DataFrame(results).sort_values('F1_weighted', ascending=False)
print("\n" + "="*50)
print("MODEL COMPARISON")
print("="*50)
display(results_df)
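
Weighted F1 follows the class frequencies, so it can look healthy even when a rare stage is never predicted. As a hedged illustration on toy labels (not results from this dataset), the sketch below compares it with macro F1 and balanced accuracy, which weight every class equally.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, f1_score

# Toy labels: class 0 dominates, class 1 is rare (illustrative values only).
y_true = np.array([0] * 8 + [1] * 2)
y_pred = np.array([0] * 10)  # a classifier that always predicts the majority class

weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
bal_acc = balanced_accuracy_score(y_true, y_pred)

print(f"F1 weighted: {weighted:.3f}")
print(f"F1 macro:    {macro:.3f}")
print(f"Balanced accuracy: {bal_acc:.3f}")
```

If the stage counts printed in section 2 are strongly unbalanced, adding one of these columns to `results` would make the model ranking less dependent on the largest class.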

## 5. Cross-Validation

In [None]:
# Cross-validate the best model
best_model_name = results_df.iloc[0]['Model']
best_model = models[best_model_name]

print(f"Cross-validation with: {best_model_name}")
print("="*50)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Accuracy
cv_acc = cross_val_score(best_model, X_train_scaled, y_train, cv=cv, scoring='accuracy')
print(f"\nAccuracy: {cv_acc.mean():.4f} ± {cv_acc.std():.4f}")
print(f"  Folds: {cv_acc}")

# F1
cv_f1 = cross_val_score(best_model, X_train_scaled, y_train, cv=cv, scoring='f1_weighted')
print(f"\nF1 (weighted): {cv_f1.mean():.4f} ± {cv_f1.std():.4f}")
print(f"  Folds: {cv_f1}")
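
The cell above cross-validates on `X_train_scaled`, which was standardized with statistics from the whole training split, so each validation fold has already influenced the scaling. One common way to avoid this is to put the scaler and the model in a scikit-learn `Pipeline` and cross-validate on the unscaled features. The sketch below shows the pattern on synthetic data; `X_demo` and `y_demo` are placeholders, and logistic regression is used only as an example estimator.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the unscaled training features and stage labels.
X_demo, y_demo = make_classification(
    n_samples=300, n_features=10, n_informative=5, n_classes=3, random_state=42
)

# The scaler is refit on the training part of every fold,
# so the held-out fold never contributes to the mean and variance.
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(pipe, X_demo, y_demo, cv=cv, scoring="f1_weighted")

print(f"F1 (weighted): {scores.mean():.4f} ± {scores.std():.4f}")
```

In this notebook the equivalent call would take `X_train` rather than `X_train_scaled`.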

## 6. Confusion Matrix

In [None]:
# Train the final model (full training data)
final_model = models[best_model_name]
final_model.fit(X_train_scaled, y_train)
y_pred_final = final_model.predict(X_test_scaled)

# Confusion matrix
cm = confusion_matrix(y_test, y_pred_final)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0])
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')
axes[0].set_title('Confusion Matrix (Counts)')

# Normalized
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[1])
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')
axes[1].set_title('Confusion Matrix (Normalized)')

plt.suptitle(f'{best_model_name} - Phase Classification', fontsize=14, y=1.02)
plt.tight_layout()
fig.savefig(f"{PROJECT_ROOT}/results/figures/confusion_matrix.png", dpi=150, bbox_inches='tight')
print("✓ Saved: results/figures/confusion_matrix.png")
plt.show()
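
The row normalization above divides each row by the number of true samples in that class, so the diagonal of the normalized matrix is the per-class recall. `confusion_matrix` can do the same with `normalize="true"`. A small check on toy labels:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels for three classes (illustrative only).
y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 0])

cm = confusion_matrix(y_true, y_pred)

# Manual row normalization, as in the cell above.
manual = cm.astype(float) / cm.sum(axis=1, keepdims=True)

# Built-in equivalent: normalize over the true labels (rows).
builtin = confusion_matrix(y_true, y_pred, normalize="true")

print(np.allclose(manual, builtin))
print(np.diag(builtin))  # per-class recall
```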

## 7. Feature Importance

In [None]:
# Extract feature importance
if hasattr(final_model, 'feature_importances_'):
    importance = final_model.feature_importances_
elif hasattr(final_model, 'coef_'):
    importance = np.abs(final_model.coef_).mean(axis=0)
else:
    importance = None

if importance is not None:
    importance_df = pd.DataFrame({
        'Feature': feature_cols,
        'Importance': importance
    }).sort_values('Importance', ascending=False)
    
    print("\nTop 10 Most Important Features:")
    print("="*50)
    display(importance_df.head(10))
    
    # Visualization
    fig, ax = plt.subplots(figsize=(10, 6))
    top_n = 15
    top_features = importance_df.head(top_n)
    
    colors = ['#e74c3c' if 'IT_' in f else '#3498db' for f in top_features['Feature']]
    ax.barh(range(len(top_features)), top_features['Importance'], color=colors)
    ax.set_yticks(range(len(top_features)))
    ax.set_yticklabels(top_features['Feature'])
    ax.invert_yaxis()
    ax.set_xlabel('Importance')
    ax.set_title(f'Top {top_n} Features for Phase Classification\n(Red=IT signature, Blue=Pathway)')
    
    plt.tight_layout()
    fig.savefig(f"{PROJECT_ROOT}/results/figures/feature_importance.png", dpi=150)
    print("\n✓ Saved: results/figures/feature_importance.png")
    plt.show()
    
    # Save
    importance_df.to_csv(f"{PROJECT_ROOT}/results/tables/feature_importance.csv", index=False)
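
`feature_importances_` and `|coef_|` are model-specific and not directly comparable across the three model types. Permutation importance is a model-agnostic alternative: it measures how much a held-out score drops when one feature is shuffled. The sketch below is an illustration on synthetic data, not part of the original analysis; the estimator and all sizes are arbitrary choices.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the feature matrix and labels.
X_demo, y_demo = make_classification(
    n_samples=400, n_features=6, n_informative=3, n_redundant=0, random_state=0
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=0
)

rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_tr, y_tr)

# Shuffle each feature several times on held-out data and record the score drop.
result = permutation_importance(
    rf, X_te, y_te, scoring="f1_weighted", n_repeats=5, random_state=0
)

# Feature indices ordered from most to least important.
ranking = result.importances_mean.argsort()[::-1]
for idx in ranking:
    print(f"feature {idx}: {result.importances_mean[idx]:.4f} ± {result.importances_std[idx]:.4f}")
```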

## 8. IT vs non-IT Binary Classification

In [None]:
# IT vs Others (binary classification)
y_binary = np.where(y == 'IT', 1, 0)

X_train_b, X_test_b, y_train_b, y_test_b = train_test_split(
    X, y_binary, test_size=0.2, stratify=y_binary, random_state=42
)

X_train_b_scaled = scaler.fit_transform(X_train_b)
X_test_b_scaled = scaler.transform(X_test_b)

# Train binary classifier
if HAS_XGB:
    binary_model = XGBClassifier(n_estimators=100, max_depth=5, random_state=42, eval_metric='logloss')
else:
    binary_model = RandomForestClassifier(n_estimators=100, random_state=42)

binary_model.fit(X_train_b_scaled, y_train_b)
y_pred_b = binary_model.predict(X_test_b_scaled)

print("\nIT vs non-IT Binary Classification:")
print("="*50)
print(classification_report(y_test_b, y_pred_b, target_names=['non-IT', 'IT']))

# ROC curve
from sklearn.metrics import roc_curve, auc

y_proba = binary_model.predict_proba(X_test_b_scaled)[:, 1]
fpr, tpr, _ = roc_curve(y_test_b, y_proba)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(7, 6))
ax.plot(fpr, tpr, color='#e74c3c', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('IT vs non-IT Classification\nROC Curve')
ax.legend(loc='lower right')

plt.tight_layout()
fig.savefig(f"{PROJECT_ROOT}/results/figures/IT_binary_ROC.png", dpi=150)
print("✓ Saved: results/figures/IT_binary_ROC.png")
plt.show()
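
`predict` uses the model's default decision threshold, while the ROC curve covers every possible threshold. One simple way to pick an operating point is Youden's J statistic, J = TPR - FPR. The sketch below uses toy scores; in practice the threshold would be chosen on a validation split rather than the test set, since tuning on test data gives optimistic estimates.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy labels and predicted probabilities for the positive (IT) class.
y_true = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1])
y_score = np.array([0.1, 0.2, 0.25, 0.3, 0.6, 0.35, 0.7, 0.8, 0.9])

fpr, tpr, thresholds = roc_curve(y_true, y_score)

# Youden's J: vertical distance between the ROC curve and the chance line.
j = tpr - fpr
best = int(np.argmax(j))
best_threshold = float(thresholds[best])

# Apply the chosen threshold instead of the default.
y_pred_custom = (y_score >= best_threshold).astype(int)

print(f"Best threshold: {best_threshold:.2f} (J = {j[best]:.2f})")
```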

## 9. Save Model

In [None]:
import os
import pickle

# Save models
model_path = f"{PROJECT_ROOT}/results/models"
os.makedirs(model_path, exist_ok=True)

# Multi-class model
# `scaler` was refit on the binary split in section 8, so rebuild the scaler
# that matches the data the multi-class model was trained on.
phase_scaler = StandardScaler().fit(X_train)
with open(f"{model_path}/phase_classifier.pkl", 'wb') as f:
    pickle.dump({
        'model': final_model,
        'scaler': phase_scaler,
        'label_encoder': le,
        'feature_cols': feature_cols
    }, f)
print("✓ Saved: results/models/phase_classifier.pkl")

# Binary model (`scaler` currently holds the fit on the binary training split)
with open(f"{model_path}/it_binary_classifier.pkl", 'wb') as f:
    pickle.dump({
        'model': binary_model,
        'scaler': scaler,
        'feature_cols': feature_cols
    }, f)
print("✓ Saved: results/models/it_binary_classifier.pkl")
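
Because each saved bundle pairs a model with a scaler, it is worth confirming after saving that the restored pair reproduces the original predictions and that the scaler statistics match the data the model was trained on. A minimal in-memory round-trip check on toy data (the feature names and labels are placeholders):

```python
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Toy training data; stage labels and feature names are placeholders.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(60, 4))
stages = np.array(["IT", "IA", "AR"] * 20)

le = LabelEncoder()
y_demo = le.fit_transform(stages)
scaler = StandardScaler().fit(X_demo)
model = LogisticRegression(max_iter=1000).fit(scaler.transform(X_demo), y_demo)

bundle = {
    "model": model,
    "scaler": scaler,
    "label_encoder": le,
    "feature_cols": [f"IT_demo_{i}" for i in range(4)],
}

# Serialize and restore without touching the disk.
restored = pickle.loads(pickle.dumps(bundle))

before = model.predict(scaler.transform(X_demo))
after = restored["model"].predict(restored["scaler"].transform(X_demo))

# The restored scaler should carry the statistics of the model's training data.
same_predictions = np.array_equal(before, after)
same_scaler = np.allclose(restored["scaler"].mean_, X_demo.mean(axis=0))
print(same_predictions, same_scaler)
```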

## 10. Prediction Function

In [None]:
def predict_phase(adata_new, model_path=f"{PROJECT_ROOT}/results/models/phase_classifier.pkl"):
    """Predict disease phase for new data.
    
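    Parameters
    ----------
    adata_new : sc.AnnData
        New data with IT signature and pathway scores in `.obs`
    model_path : str
        Path to the pickled bundle saved in section 9
        
    Returns
    -------
    predictions : np.ndarray
        Predicted phase label for each row of `adata_new.obs`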
    """
    import pickle
    
    # Load model
    with open(model_path, 'rb') as f:
        saved = pickle.load(f)
    
    model = saved['model']
    scaler = saved['scaler']
    le = saved['label_encoder']
    feature_cols = saved['feature_cols']
    
    # Check features
    missing = [c for c in feature_cols if c not in adata_new.obs.columns]
    if missing:
        raise ValueError(f"Missing features: {missing}")
    
    # Extract features
    X_new = adata_new.obs[feature_cols].values
    X_new = np.nan_to_num(X_new, nan=0.0)
    X_new_scaled = scaler.transform(X_new)
    
    # Predict
    y_pred = model.predict(X_new_scaled)
    predictions = le.inverse_transform(y_pred)
    
    return predictions

print("✓ predict_phase() function defined")
print("\nUsage:")
print("  predictions = predict_phase(new_adata)")
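
`predict_phase()` returns hard labels only. When a confidence estimate is useful, the same bundle can return per-class probabilities. The sketch below assumes a bundle with the same keys as the saved pickle and takes a plain DataFrame in place of `adata_new.obs`. `predict_phase_proba` is a hypothetical helper, and the toy features and labels are placeholders.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder, StandardScaler


def predict_phase_proba(obs, bundle):
    """Hypothetical helper: per-class probabilities as a DataFrame."""
    cols = bundle["feature_cols"]
    X_new = np.nan_to_num(obs[cols].to_numpy(dtype=float), nan=0.0)
    proba = bundle["model"].predict_proba(bundle["scaler"].transform(X_new))
    # model.classes_ holds the encoded integers; map them back to stage names.
    labels = bundle["label_encoder"].inverse_transform(bundle["model"].classes_)
    return pd.DataFrame(proba, index=obs.index, columns=labels)


# Toy bundle built in memory (placeholder features and labels).
rng = np.random.default_rng(1)
obs = pd.DataFrame(rng.normal(size=(30, 2)), columns=["IT_demo", "PW_demo"])
stages = np.array(["IT", "IA", "AC"] * 10)

le = LabelEncoder().fit(stages)
scaler = StandardScaler().fit(obs.to_numpy())
model = LogisticRegression(max_iter=1000).fit(
    scaler.transform(obs.to_numpy()), le.transform(stages)
)
bundle = {
    "model": model,
    "scaler": scaler,
    "label_encoder": le,
    "feature_cols": ["IT_demo", "PW_demo"],
}

proba_df = predict_phase_proba(obs, bundle)
print(proba_df.head())
```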

---
## Summary

### Model Performance:
- **Multi-class (5 stages)**: Accuracy, F1 score
- **Binary (IT vs non-IT)**: AUC-ROC

### Key Features:
- IT signature scores (NK collapse, Mito-high, B cell block)
- Pathway activities (mTOR, glycolysis, OXPHOS)

### Saved Models:
- `phase_classifier.pkl`: Multi-class
- `it_binary_classifier.pkl`: IT vs non-IT

### Clinical significance:
- Disease stage can be predicted from scRNA-seq data of new patients
- Early identification of patients in the IT phase

---
## ITLAS Development Complete! 🎉

### Next steps:
1. **Write the paper**: IT-immunopathogenesis paper
2. **Package on GitHub**: `pip install itlas`
3. **Additional validation**: verify on an independent cohort