# SpatioResist: Spatial Resistance Prediction Algorithm

## Overview
This notebook develops the **SpatioResist** algorithm, a machine-learning framework that predicts immunotherapy resistance from spatial transcriptomics patterns.

### Novelty
SpatioResist integrates:
1. **Cell type spatial distributions** from deconvolution
2. **Cell-cell communication networks** from ligand-receptor analysis (planned; not yet implemented by the feature extractor in this notebook)
3. **Spatial statistics** (neighborhood enrichment, Moran's I)
4. **Gene expression signatures** of resistance (currently via Moran's I of per-spot `exhaustion_score` / `resistance_score`, when computed)

### Output
- Per-sample resistance risk score
- Identification of resistance-driving spatial niches
- Actionable spatial biomarkers

---

In [None]:
import scanpy as sc
import squidpy as sq
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import warnings

# ML libraries
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, classification_report
import joblib

warnings.filterwarnings('ignore')

# Project paths (resolved relative to the notebook's working directory,
# which is expected to sit two levels below the project root)
PROJECT_ROOT = Path("../..").resolve()
DATA_PROCESSED = PROJECT_ROOT / 'data' / 'processed'
MODELS = PROJECT_ROOT / 'results' / 'models'
FIGURES = PROJECT_ROOT / 'results' / 'figures'
TABLES = PROJECT_ROOT / 'results' / 'tables'
CONFIG_PATH = PROJECT_ROOT / 'config' / 'analysis_params.yaml'

# Add custom module directory to the import path. The membership check keeps
# re-runs of this cell from appending the same entry again and again.
import sys
SRC_DIR = str(PROJECT_ROOT / 'src')
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Load analysis parameters. An explicit encoding makes decoding independent
# of the platform's locale default.
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Global NumPy seed for reproducibility; the models also receive it explicitly
# through config['random_seed'].
SEED = config['random_seed']
np.random.seed(SEED)

## 1. Feature Engineering

Extract features from spatial transcriptomics data that characterize the architecture of the tumor microenvironment (TME).

In [None]:
class SpatialFeatureExtractor:
    """
    Extract per-sample spatial features for resistance prediction.

    Every ``extract_*`` method takes one AnnData-like object (only ``.obs``
    and ``.uns`` are read) and returns a flat ``{feature_name: value}`` dict,
    so one dict per sample can be stacked into a feature matrix.

    Parameters
    ----------
    config : dict
        Analysis configuration. Stored on the instance; the extraction
        methods do not read it.
    """

    # Prefix identifying per-spot cell-type abundance columns in ``adata.obs``
    _ABUNDANCE_PREFIX = 'abundance_'

    def __init__(self, config):
        self.config = config

    def extract_abundance_features(self, adata):
        """
        Summarise every ``abundance_<cell type>`` column in ``adata.obs``.

        For each cell type this emits ``<ct>_mean``, ``<ct>_std``,
        ``<ct>_max``, ``<ct>_q75`` (75th percentile) and ``<ct>_gini``
        (concentration of the abundance across spots).

        A sample without spots yields NaN for every statistic instead of
        raising (``np.max`` on an empty array raises ``ValueError``).
        """
        features = {}
        prefix = self._ABUNDANCE_PREFIX

        for col in adata.obs.columns:
            if not col.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove
            # any later occurrence inside the cell-type name itself.
            ct = col[len(prefix):]
            values = np.asarray(adata.obs[col].values)

            if values.size == 0:
                for stat in ('mean', 'std', 'max', 'q75', 'gini'):
                    features[f'{ct}_{stat}'] = np.nan
                continue

            # Basic statistics
            features[f'{ct}_mean'] = np.mean(values)
            features[f'{ct}_std'] = np.std(values)
            features[f'{ct}_max'] = np.max(values)
            features[f'{ct}_q75'] = np.percentile(values, 75)

            # Spatial concentration
            features[f'{ct}_gini'] = self._gini_coefficient(values)

        return features

    def extract_spatial_stats(self, adata):
        """
        Extract neighbourhood-enrichment and Moran's I features.

        * ``adata.uns['nhood_enrichment']['zscore']``: square z-score matrix.
          The upper triangle including the diagonal becomes
          ``nhood_<a>_<b>`` features. Labels come from the DataFrame index
          when ``zscore`` is a DataFrame; for a plain array (squidpy's storage
          format), positional labels ``0..n-1`` are used.
        * ``adata.uns['moranI']``: table indexed by variable with an ``'I'``
          column. ``exhaustion_score`` / ``resistance_score`` rows, when
          present, become ``moran_<name>`` features.
        """
        features = {}

        # Neighborhood enrichment results
        if 'nhood_enrichment' in adata.uns:
            zscore = adata.uns['nhood_enrichment']['zscore']
            if isinstance(zscore, pd.DataFrame):
                labels = list(zscore.index)
                matrix = zscore.to_numpy()
            else:
                # NOTE(review): positional labels are only comparable across
                # samples if the cluster categories are ordered identically.
                matrix = np.asarray(zscore)
                labels = list(range(matrix.shape[0]))

            # np.triu_indices walks the upper triangle row by row, i.e. the
            # same (i <= j) order as a nested i/j loop.
            rows, cols = np.triu_indices(len(labels))
            for i, j in zip(rows, cols):
                features[f'nhood_{labels[i]}_{labels[j]}'] = matrix[i, j]

        # Spatial autocorrelation (Moran's I)
        if 'moranI' in adata.uns:
            moran = adata.uns['moranI']
            for score in ('exhaustion_score', 'resistance_score'):
                if score in moran.index:
                    features[f'moran_{score}'] = moran.loc[score, 'I']

        return features

    def extract_niche_features(self, adata):
        """
        Return the fraction of spots in each ``adata.obs['spatial_niche']``
        label as ``niche_<label>_proportion`` features (empty if absent).
        """
        features = {}

        if 'spatial_niche' in adata.obs.columns:
            proportions = adata.obs['spatial_niche'].value_counts(normalize=True)
            for niche, prop in proportions.items():
                features[f'niche_{niche}_proportion'] = prop

        return features

    def extract_all_features(self, adata):
        """
        Merge abundance, spatial-statistic and niche features into one dict.
        """
        features = {}
        features.update(self.extract_abundance_features(adata))
        features.update(self.extract_spatial_stats(adata))
        features.update(self.extract_niche_features(adata))
        return features

    @staticmethod
    def _gini_coefficient(values):
        """
        Gini coefficient of ``values``: 0 for an even spread, larger when
        the mass is concentrated in few entries.

        Computed as ``2 * sum(i * x_i) / (n * sum(x)) - (n + 1) / n`` over
        the ascending-sorted values ``x_1..x_n``. It assumes non-negative
        input. Empty input or an all-zero total returns 0.0, because the
        formula would otherwise divide by zero and produce NaN.
        """
        values = np.sort(np.asarray(values, dtype=float))
        n = values.size
        total = values.sum()
        if n == 0 or total == 0:
            return 0.0
        index = np.arange(1, n + 1)
        return (2 * np.sum(index * values) / (n * total)) - (n + 1) / n


print("SpatialFeatureExtractor defined")

## 2. SpatioResist Model

In [None]:
class SpatioResist:
    """
    SpatioResist: Spatial Resistance Prediction Model.

    Predicts immunotherapy resistance from spatial transcriptomics features
    by chaining a SpatialFeatureExtractor, a StandardScaler and one
    scikit-learn classifier. Labels are 0 = responder, 1 = non-responder.
    """

    # Accepted values for the ``model_type`` constructor argument
    _MODEL_TYPES = ('rf', 'gb', 'lr')

    def __init__(self, config, model_type='rf'):
        """
        Initialize SpatioResist.

        Parameters
        ----------
        config : dict
            Configuration dictionary; must contain ``'random_seed'``.
        model_type : str
            Model type: 'rf' (Random Forest), 'gb' (Gradient Boosting),
            'lr' (Logistic Regression).

        Raises
        ------
        ValueError
            If ``model_type`` is not one of the values above. An unknown
            value raises rather than silently selecting logistic regression.
        """
        if model_type not in self._MODEL_TYPES:
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                f"expected one of {self._MODEL_TYPES}"
            )

        self.config = config
        self.model_type = model_type
        self.feature_extractor = SpatialFeatureExtractor(config)
        self.scaler = StandardScaler()
        # Training column order. Set by fit() (provisionally by
        # extract_features() before the first fit) and restored by load().
        self.feature_names = None

        seed = config['random_seed']
        if model_type == 'rf':
            self.model = RandomForestClassifier(
                n_estimators=100,
                max_depth=10,
                random_state=seed,
                n_jobs=-1,
            )
        elif model_type == 'gb':
            self.model = GradientBoostingClassifier(
                n_estimators=100,
                max_depth=5,
                random_state=seed,
            )
        else:  # 'lr'
            self.model = LogisticRegression(
                random_state=seed,
                max_iter=1000,
            )

    def extract_features(self, adata_list):
        """
        Extract features from multiple spatial samples.

        Parameters
        ----------
        adata_list : list of AnnData
            List of spatial AnnData objects

        Returns
        -------
        pd.DataFrame
            Feature matrix, one row per sample in ``adata_list`` order.
            Features missing from a sample are NaN.
        """
        df = pd.DataFrame(
            [self.feature_extractor.extract_all_features(a) for a in adata_list]
        )
        # Adopt these columns only while no training schema exists. Otherwise
        # extracting features for new samples after fit() would overwrite the
        # column order that predict() and get_feature_importance() rely on.
        if self.feature_names is None:
            self.feature_names = df.columns.tolist()
        return df

    def fit(self, X, y):
        """
        Train the resistance prediction model.

        Parameters
        ----------
        X : pd.DataFrame
            Feature matrix; NaNs are treated as 0. Its columns become the
            training schema (``feature_names``).
        y : array-like
            Response labels (0=responder, 1=non-responder)

        Returns
        -------
        SpatioResist
            ``self``, for chaining.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        X = X.fillna(0)
        self.feature_names = X.columns.tolist()

        X_scaled = self.scaler.fit_transform(X)
        self.model.fit(X_scaled, y)
        return self

    def _prepare(self, X):
        """Align X to the training columns and fill NaNs with 0.

        For DataFrames, columns unseen during training are dropped and
        training columns the sample lacks are added as 0. Feature sets vary
        per sample (cell types, niches, neighbourhood pairs), and the fitted
        scaler rejects column names that differ from those seen in fit.
        """
        if isinstance(X, pd.DataFrame):
            if self.feature_names is not None:
                X = X.reindex(columns=self.feature_names)
        else:
            X = pd.DataFrame(X)
        return X.fillna(0)

    def predict(self, X):
        """
        Predict resistance labels (0=responder, 1=non-responder).
        """
        X_scaled = self.scaler.transform(self._prepare(X))
        return self.model.predict(X_scaled)

    def predict_proba(self, X):
        """
        Predict resistance probability (probability of class 1).
        """
        X_scaled = self.scaler.transform(self._prepare(X))
        return self.model.predict_proba(X_scaled)[:, 1]

    def get_feature_importance(self):
        """
        Return feature importance scores sorted in descending order.

        Tree ensembles use ``feature_importances_``; logistic regression
        uses ``|coef_[0]|`` (magnitude on standardized features). Returns
        None when the model exposes neither attribute (e.g. before fit).
        """
        if hasattr(self.model, 'feature_importances_'):
            importance = self.model.feature_importances_
        elif hasattr(self.model, 'coef_'):
            importance = np.abs(self.model.coef_[0])
        else:
            return None

        return pd.Series(importance, index=self.feature_names).sort_values(ascending=False)

    def save(self, path):
        """Save model, scaler, feature names and model type with joblib.

        The parent directory of ``path`` is created if it does not exist.
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({
            'model': self.model,
            'scaler': self.scaler,
            'feature_names': self.feature_names,
            'model_type': getattr(self, 'model_type', None),
        }, path)

    @classmethod
    def load(cls, path, config):
        """Load a model written by ``save``.

        joblib files are pickles, and unpickling can execute arbitrary code.
        Only load files from trusted sources.
        """
        data = joblib.load(path)
        # Files written before 'model_type' was stored fall back to 'rf';
        # the restored estimator below replaces the freshly built one anyway.
        instance = cls(config, model_type=data.get('model_type') or 'rf')
        instance.model = data['model']
        instance.scaler = data['scaler']
        instance.feature_names = data['feature_names']
        return instance


print("SpatioResist model defined")

## 3. Train SpatioResist Model

Train the model using spatial data with response labels.

In [None]:
# Load deconvolved spatial samples. Response labels are not loaded here; when
# they are added, they must follow the sample order established below.
#
# Sort the file list: glob order is filesystem-dependent, and a stable order
# is needed to line samples up with clinical labels across runs.
spatial_files = sorted((DATA_PROCESSED / 'spatial').glob('*_deconvolved.h5ad'))

# Sample identifiers (file name without the '_deconvolved.h5ad' suffix), in
# the same order as adata_list, for joining with clinical response data.
sample_ids = [f.name[:-len('_deconvolved.h5ad')] for f in spatial_files]

# Always defined (empty when no files exist) so later cells can rely on it.
adata_list = [sc.read_h5ad(f) for f in spatial_files]

if spatial_files:
    print(f"Loaded {len(adata_list)} spatial samples")
else:
    print("No deconvolved spatial data found.")
    print("Please run cell2location deconvolution first.")

In [None]:
# Response labels: none are created in this cell. They should come from
# clinical data; once that file exists, load it with e.g.
# response_labels = pd.read_csv(DATA_PROCESSED / 'clinical_response.csv')

# Build a Random Forest-backed SpatioResist model from the loaded config
spatioresist = SpatioResist(config=config, model_type='rf')
print("SpatioResist model initialized")

In [None]:
# Extract one feature row per loaded sample and report how many features fall
# into each group of SpatialFeatureExtractor's naming scheme.
if spatial_files:
    X = spatioresist.extract_features(adata_list)
    print(f"Extracted {X.shape[1]} features from {X.shape[0]} samples")

    # Spatial stats start with nhood_/moran_ and niche proportions with
    # niche_. Abundance stats end in one of the per-cell-type suffixes.
    # Counting every abundance suffix (not only _mean/_std) and excluding
    # the other groups keeps the three counts disjoint and complete.
    spatial_cols = {c for c in X.columns if c.startswith(('nhood_', 'moran_'))}
    niche_cols = {c for c in X.columns if c.startswith('niche_')}
    abundance_cols = {
        c for c in X.columns
        if c.endswith(('_mean', '_std', '_max', '_q75', '_gini'))
        and c not in spatial_cols
        and c not in niche_cols
    }

    print("\nFeature groups:")
    print(f"  Abundance features: {len(abundance_cols)}")
    print(f"  Spatial stats: {len(spatial_cols)}")
    print(f"  Niche features: {len(niche_cols)}")

In [None]:
# Cross-validation (requires real response labels)
# Uncomment when you have real response labels. y must be in the same order
# as the rows of X, which follow the order of `adata_list`.
#
# y = response_labels['response'].values  # 0=responder, 1=non-responder
#
# NOTE(review): this cross-validates the bare estimator on unscaled features,
# bypassing SpatioResist's StandardScaler. That does not matter for the tree
# models ('rf'/'gb') but does for 'lr'; consider an sklearn Pipeline of
# scaler + model. StratifiedKFold(n_splits=5) also needs at least 5 samples
# in each class.
# cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
# scores = cross_val_score(spatioresist.model, X.fillna(0), y, cv=cv, scoring='roc_auc')
#
# print(f"Cross-validation AUC: {scores.mean():.3f} +/- {scores.std():.3f}")

print("Model training code ready - add response labels to train")

## 4. Feature Importance Analysis

In [None]:
# Get feature importance. Only meaningful after spatioresist.fit(X, y) has run.
# get_feature_importance() returns None when the model exposes neither
# feature_importances_ nor coef_.
# importance = spatioresist.get_feature_importance()
#
# # Plot top features. barh draws the first row at the bottom, so reverse the
# # top-20 slice to put the most important feature at the top.
# plt.figure(figsize=(10, 8))
# importance.head(20).iloc[::-1].plot(kind='barh')
# plt.xlabel('Feature Importance')
# plt.title('Top 20 Spatial Features for Resistance Prediction')
# plt.tight_layout()
# FIGURES.mkdir(parents=True, exist_ok=True)  # results/figures may not exist yet
# plt.savefig(FIGURES / 'spatioresist_feature_importance.png', dpi=150)
# plt.show()

print("Feature importance analysis ready")

## 5. Save Model

In [None]:
# Save trained model (run after spatioresist.fit(X, y)). The results/models
# directory may not exist yet, so create it before saving.
# MODELS.mkdir(parents=True, exist_ok=True)
# model_path = MODELS / 'spatioresist_model.joblib'
# spatioresist.save(model_path)
# print(f"Model saved to: {model_path}")

print("Model saving code ready")

## 6. Generate Resistance Risk Scores

In [None]:
def score_spatial_sample(adata, model, top_n=5):
    """
    Generate a resistance risk score for a single spatial sample.

    Parameters
    ----------
    adata : AnnData
        Deconvolved spatial data for one sample.
    model : SpatioResist
        Trained SpatioResist model.
    top_n : int, optional
        How many of the model's globally most important features to inspect
        when reporting contributors (default 5).

    Returns
    -------
    float
        Resistance risk score in [0, 1]: the model's probability for
        class 1 (non-responder).
    dict
        ``{feature: {'value': <sample value>, 'importance': <global importance>}}``
        for those of the top ``top_n`` features that this sample actually
        has. Empty when the model exposes no importances.
    """
    features = model.feature_extractor.extract_all_features(adata)
    X = pd.DataFrame([features])

    # One sample seldom has exactly the training columns (cell types, niches
    # and neighbourhood pairs vary per sample), and scikit-learn rejects
    # column names that differ from those seen during fit. Align to the
    # training schema: extra columns are dropped and missing ones become NaN,
    # which predict_proba fills with 0.
    if model.feature_names is not None:
        X = X.reindex(columns=model.feature_names)

    risk_score = float(model.predict_proba(X)[0])

    # These are global model importances, not per-sample attributions. The
    # dict reports this sample's values for the model's most important features.
    importance = model.get_feature_importance()

    top_contributors = {}
    if importance is not None:
        for feat in importance.head(top_n).index:
            if feat in features:
                top_contributors[feat] = {
                    'value': features[feat],
                    'importance': importance[feat],
                }

    return risk_score, top_contributors


print("Scoring function defined")

## Summary

### SpatioResist Algorithm
- Integrates spatial cell type abundances, neighborhood statistics, and niche compositions
- Predicts resistance probability per sample
- Identifies key spatial features driving resistance

### Key Outputs
1. Per-sample resistance risk score
2. Feature importance ranking
3. Spatial biomarker candidates

### Next Steps
1. Validate on external cohorts in `08_validation/`
2. Correlate with survival outcomes
3. Identify druggable spatial targets