# External Validation of Resistance Signatures

## Overview
This notebook validates the identified resistance signatures and SpatioResist model predictions using independent external cohorts.

### Objectives
1. Apply resistance signatures to external datasets
2. Validate SpatioResist predictions
3. Assess cross-cancer generalizability
4. Correlate with clinical outcomes

### Validation Strategy
- **Independent scRNA-seq cohorts**: Test resistance gene signatures
- **TCGA bulk RNA-seq**: Large-scale survival analysis
- **Clinical trial data**: Response prediction accuracy

---

In [None]:
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import warnings

# Statistics and survival
from scipy import stats
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve

# Silence all warnings so the notebook output stays compact
warnings.filterwarnings('ignore')

# Directory layout; the project root is resolved two levels above the
# current working directory
PROJECT_ROOT = Path("../..").resolve()
DATA_PROCESSED = PROJECT_ROOT.joinpath('data', 'processed')
DATA_EXTERNAL = PROJECT_ROOT.joinpath('data', 'external')
MODELS = PROJECT_ROOT.joinpath('results', 'models')
FIGURES = PROJECT_ROOT.joinpath('results', 'figures')
TABLES = PROJECT_ROOT.joinpath('results', 'tables')
CONFIG_PATH = PROJECT_ROOT.joinpath('config', 'analysis_params.yaml')

# Analysis parameters are read from the shared YAML config
with CONFIG_PATH.open('r') as f:
    config = yaml.safe_load(f)

# Seed NumPy's global RNG from the configured value
SEED = config['random_seed']
np.random.seed(SEED)

## 1. Load Resistance Signatures

In [None]:
# Resistance signatures come from the analysis config
# (they are expected to be produced by the resistance analysis notebooks)
resistance_signatures = config['resistance']['signatures']

# Report each signature's size followed by its gene list
print("Resistance signatures:")
for name in resistance_signatures:
    genes = resistance_signatures[name]
    print(f"  {name}: {len(genes)} genes\n    {genes}")

In [None]:
def score_signature(adata, signature_genes, score_name):
    """
    Add a gene-signature score column to ``adata.obs``.

    Only signature genes that occur in ``adata.var_names`` are used. When
    none are present, the score column is filled with 0 instead.

    Parameters
    ----------
    adata : AnnData
        Data matrix; modified in place
    signature_genes : list
        Genes making up the signature
    score_name : str
        Column name under which the score is stored in ``adata.obs``

    Returns
    -------
    AnnData
        The same object that was passed in, with the score added
    """
    # Keep the signature's order, dropping genes the dataset does not contain
    matched = [gene for gene in signature_genes if gene in adata.var_names]

    # Nothing to score with: fall back to a constant score
    if not matched:
        print(f"{score_name}: no genes found in data")
        adata.obs[score_name] = 0
        return adata

    sc.tl.score_genes(adata, gene_list=matched, score_name=score_name)
    print(f"{score_name}: scored with {len(matched)}/{len(signature_genes)} genes")
    return adata

## 2. Validate on External scRNA-seq Cohort

In [None]:
# Load the external validation dataset
# This would be an independent immunotherapy cohort

# Example: Load a validation dataset
validation_path = DATA_EXTERNAL / 'validation_cohort.h5ad'

# Bind the name unconditionally so that later cells can test
# `adata_val is not None` instead of failing with a NameError (or silently
# reusing a stale object from an earlier run) when the file is missing.
adata_val = None

if validation_path.exists():
    adata_val = sc.read_h5ad(validation_path)
    print(f"Loaded validation cohort: {adata_val.n_obs} cells")
else:
    print(f"Validation data not found at: {validation_path}")
    print("Please add external validation datasets to data/external/")

In [None]:
# Score validation cohort with resistance signatures.
# The loop below is intentionally left commented out: it needs a loaded
# `adata_val`, and the loading cell only reads one when the validation file
# exists. Each signature would be stored in obs as '<signature>_score'.
# Uncomment when validation data is available:

# for sig_name, sig_genes in resistance_signatures.items():
#     adata_val = score_signature(adata_val, sig_genes, f'{sig_name}_score')

print("Signature scoring code ready")

## 3. TCGA Bulk RNA-seq Validation

Validate signatures using TCGA data for survival analysis.

In [None]:
def load_tcga_data(cancer_type):
    """
    Read the TCGA expression and clinical tables for one cancer type.

    Both files are looked up in ``DATA_EXTERNAL`` as
    ``TCGA_<cancer_type>_expression.csv`` and
    ``TCGA_<cancer_type>_clinical.csv``; the first column of each file is
    used as the index.

    Parameters
    ----------
    cancer_type : str
        TCGA cancer type abbreviation (e.g., 'SKCM', 'LUAD')

    Returns
    -------
    pd.DataFrame, pd.DataFrame
        Expression matrix and clinical data, or ``(None, None)`` when
        either file is missing
    """
    # In practice, download from TCGA or use TCGAbiolinks
    expression_file = DATA_EXTERNAL / f'TCGA_{cancer_type}_expression.csv'
    clinical_file = DATA_EXTERNAL / f'TCGA_{cancer_type}_clinical.csv'

    # Both tables are required; bail out if either one is absent
    if not (expression_file.exists() and clinical_file.exists()):
        print(f"TCGA data not found for {cancer_type}")
        return None, None

    expression = pd.read_csv(expression_file, index_col=0)
    clinical = pd.read_csv(clinical_file, index_col=0)
    return expression, clinical

print("TCGA loading function defined")

In [None]:
def score_bulk_signature(expr_df, signature_genes):
    """
    Score bulk RNA-seq samples using a gene signature.

    Each signature gene is z-scored across samples, and a sample's score is
    the mean of its z-scores over the signature genes.

    Parameters
    ----------
    expr_df : pd.DataFrame
        Expression matrix (genes x samples)
    signature_genes : list
        Gene list for scoring; genes absent from ``expr_df.index`` are ignored

    Returns
    -------
    pd.Series
        Score for each sample, indexed by ``expr_df.columns``. All zeros
        when no signature gene is present in the matrix.
    """
    # Filter to present genes
    present = [g for g in signature_genes if g in expr_df.index]

    if not present:
        return pd.Series(0, index=expr_df.columns)

    expr_subset = expr_df.loc[present]

    # Standardize each gene (row) across samples. Standardizing per sample
    # (the pandas default axis) would center every column on its own mean,
    # making the per-sample mean z-score identically zero.
    gene_mean = expr_subset.mean(axis=1)
    gene_std = expr_subset.std(axis=1)
    z_scores = expr_subset.sub(gene_mean, axis=0).div(gene_std, axis=0)

    # Mean z-score per sample. Genes with zero variance produce NaN z-scores,
    # which the default skipna behaviour leaves out of the mean.
    return z_scores.mean(axis=0)

print("Bulk scoring function defined")

## 4. Survival Analysis

In [None]:
def perform_survival_analysis(clinical_df, scores, score_name, time_col='OS_time', event_col='OS_event'):
    """
    Perform Kaplan-Meier, log-rank and Cox survival analysis for a score.

    Samples are split at the median score into 'High' (strictly above the
    median) and 'Low' groups for the Kaplan-Meier curves and log-rank test.
    The Cox model uses the continuous score as its only covariate.

    Parameters
    ----------
    clinical_df : pd.DataFrame
        Clinical data with survival information, indexed by sample
    scores : pd.Series
        Signature scores per sample
    score_name : str
        Name for the score (used in plot labels and title)
    time_col : str
        Column of ``clinical_df`` holding follow-up time. The plot axis is
        labelled in days; this assumes the column is in days.
    event_col : str
        Column of ``clinical_df`` holding the event indicator

    Returns
    -------
    dict or None
        ``p_value`` (log-rank), ``hazard_ratio`` (Cox, per unit of score),
        ``n_samples`` and ``figure``. ``None`` when fewer than 20 complete
        samples remain or when the median split leaves the High group empty.
    """
    # Align data
    common_samples = clinical_df.index.intersection(scores.index)

    df = clinical_df.loc[common_samples].copy()
    df['score'] = scores.loc[common_samples]

    # Remove missing values
    df = df.dropna(subset=[time_col, event_col, 'score'])

    if len(df) < 20:
        print(f"Insufficient samples for survival analysis: {len(df)}")
        return None

    # Stratify by median
    median_score = df['score'].median()
    high_mask = df['score'] > median_score

    # Degenerate scores (e.g. all equal) leave no sample above the median,
    # and the group-wise fits below cannot run on an empty group.
    if not high_mask.any():
        print(f"Cannot stratify {score_name} by median: no samples above the median score")
        return None

    df['score_group'] = high_mask.map({True: 'High', False: 'Low'})

    # Kaplan-Meier analysis
    fig, ax = plt.subplots(figsize=(8, 6))

    kmf = KaplanMeierFitter()

    for group in ['Low', 'High']:
        mask = df['score_group'] == group
        kmf.fit(
            df.loc[mask, time_col],
            df.loc[mask, event_col],
            label=f'{group} {score_name}'
        )
        kmf.plot_survival_function(ax=ax)

    # Log-rank test (High vs. Low)
    lr_result = logrank_test(
        df.loc[high_mask, time_col],
        df.loc[~high_mask, time_col],
        event_observed_A=df.loc[high_mask, event_col],
        event_observed_B=df.loc[~high_mask, event_col]
    )

    pvalue = lr_result.p_value

    ax.set_title(f'{score_name} Survival Analysis\nLog-rank p = {pvalue:.4f}')
    ax.set_xlabel('Time (days)')
    ax.set_ylabel('Survival Probability')

    plt.tight_layout()

    # Cox regression on the continuous score
    cph = CoxPHFitter()
    cox_df = df[[time_col, event_col, 'score']].copy()
    cph.fit(cox_df, duration_col=time_col, event_col=event_col)

    # hazard_ratios_ is already exp(coef); exponentiating it again would
    # report exp(exp(coef)) instead of the hazard ratio.
    hr = float(cph.hazard_ratios_['score'])

    return {
        'p_value': pvalue,
        'hazard_ratio': hr,
        'n_samples': len(df),
        'figure': fig
    }

print("Survival analysis function defined")

## 5. Response Prediction Evaluation

In [None]:
def evaluate_response_prediction(y_true, y_pred_proba):
    """
    Summarize response-prediction performance with ROC and PR curves.

    Parameters
    ----------
    y_true : array-like
        True response labels (0=responder, 1=non-responder)
    y_pred_proba : array-like
        Predicted probabilities

    Returns
    -------
    dict
        ``auc`` (ROC-AUC) and ``figure`` (two-panel ROC / precision-recall
        plot)
    """
    # Metrics
    roc_auc = roc_auc_score(y_true, y_pred_proba)
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred_proba)
    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_pred_proba)

    # One panel per curve, side by side
    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: ROC curve with the chance diagonal for reference
    roc_ax.plot(false_pos_rate, true_pos_rate, label=f'AUC = {roc_auc:.3f}')
    roc_ax.plot([0, 1], [0, 1], 'k--')
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('ROC Curve')
    roc_ax.legend()

    # Right panel: precision-recall curve
    pr_ax.plot(pr_recall, pr_precision)
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.set_title('Precision-Recall Curve')

    plt.tight_layout()

    return {'auc': roc_auc, 'figure': fig}

print("Response prediction evaluation function defined")

## 6. Cross-Cancer Validation

In [None]:
# Cancer types used for cross-cancer validation
cancer_types = ['SKCM', 'LUAD', 'STAD', 'LIHC']  # Melanoma, Lung, Gastric, Liver

# One row per (cancer type, signature) once the survival step is enabled
validation_results = []

for cancer in cancer_types:
    print(f"\nValidating in {cancer}...")

    expr, clin = load_tcga_data(cancer)

    # Skip cancer types whose TCGA files are not available
    if expr is None:
        continue

    for sig_name in resistance_signatures:
        sig_genes = resistance_signatures[sig_name]
        scores = score_bulk_signature(expr, sig_genes)

        # Note: Uncomment when TCGA data is available
        # result = perform_survival_analysis(clin, scores, sig_name)
        # if result:
        #     validation_results.append({
        #         'cancer_type': cancer,
        #         'signature': sig_name,
        #         'p_value': result['p_value'],
        #         'hazard_ratio': result['hazard_ratio'],
        #         'n_samples': result['n_samples']
        #     })

        print(f"  {sig_name}: scored")

print("\nCross-cancer validation code ready")

## 7. Save Validation Results

In [None]:
# Save validation results (one row per cancer type / signature pair)
if validation_results:
    results_df = pd.DataFrame(validation_results)
    # to_csv does not create missing parent directories, so make sure the
    # tables folder exists on a fresh checkout before writing.
    TABLES.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(TABLES / 'validation_results.csv', index=False)
    print(f"Saved validation results to {TABLES}")
    display(results_df)
else:
    print("No validation results to save. Add external data to perform validation.")

## Summary

### Validation Framework
- Signature scoring for scRNA-seq and bulk RNA-seq
- Kaplan-Meier survival analysis
- Cox regression for hazard ratios
- ROC-AUC for response prediction

### Required External Data
1. Independent immunotherapy scRNA-seq cohorts
2. TCGA expression and clinical data
3. Clinical trial response data

### Next Steps
1. Acquire external validation datasets
2. Run survival analysis in `08b_survival_analysis.ipynb`
3. Export final atlas for publication