# Multi-Task Drug Response Prediction

**Goal**: Predict IC50 values for multiple drugs simultaneously

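**Why Multi-Task?**
- Single-drug prediction can overfit to drug-specific patterns
- Predicting many drugs from one shared feature representation (the same PCA components for every drug) encourages features that generalize across drugs rather than fitting any single drug. Note: Ridge and the XGBoost wrapper (`MultiOutputRegressor`, one model per drug) still fit each drug independently. Only the Random Forest shares tree structure across drugs.
- More realistic evaluation: can the model predict drug responses broadly?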

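**Split Strategies**:
1. **Random split** (Baseline): Standard 80/20 random split of cell lines
2. **Histology-based split** (Hard): The least frequent `SPLIT_THRESHOLD` fraction (40%) of histology categories is held out, so every test cell comes from a cancer subtype never seen in training
3. **Site-based split** (Hard): The same procedure applied to primary site, so test cells come from unseen tissue origins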

## Google Colab Setup

In [None]:
import os

# Colab detection: when an IPython shell exists, check whether its string
# form mentions 'google.colab'; without IPython there is no get_ipython().
if 'get_ipython' in globals():
    IN_COLAB = 'google.colab' in str(get_ipython())
else:
    IN_COLAB = False

if not IN_COLAB:
    print("Not running in Colab - using local paths")
else:
    # Mount Google Drive and work from the shared data folder.
    from google.colab import drive
    drive.mount('/content/drive')
    DATA_DIR = '/content/drive/MyDrive/bioai_data'
    os.chdir(DATA_DIR)
    print(f"Working directory: {os.getcwd()}")
    print(f"Files available: {os.listdir('.')}")

## Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from xgboost import XGBRegressor
import warnings
warnings.filterwarnings('ignore')

## Data Loading

In [None]:
# Multi-omics feature table. Rows are indexed by the file's first column
# (presumably cell-line IDs; verify). Columns are grouped by name prefix below.
methexpr_df = pd.read_csv(
    'ml_with_gene_expr.csv.gz',
    compression='gzip',
    index_col=0,
    low_memory=False,
)

metadata_cols = ['primary site', 'primary histology', 'cosmic_id']
# Methylation probes are named 'cg...'; expression features 'expr_...'.
methylation_cols = [name for name in methexpr_df.columns if name.startswith('cg')]
expression_cols = [name for name in methexpr_df.columns if name.startswith('expr_')]

print(f"Feature data: {methexpr_df.shape}")
print(f"Methylation: {len(methylation_cols)} | Expression: {len(expression_cols)}")

In [None]:
# Drug-response table. Any column that is neither a methylation probe
# ('cg...') nor a metadata field is treated as a drug.
response_df = pd.read_csv(
    'ML_dataset_methylation_drug_response.csv.gz',
    compression='gzip',
    index_col=0,
    low_memory=False,
)

drug_cols = [
    name for name in response_df.columns
    if not (name.startswith('cg') or name in metadata_cols)
]
print(f"Drug response data: {response_df.shape}")
print(f"Total drugs: {len(drug_cols)}")

## Drug Selection by Coverage

In [None]:
# Keep only drugs with a measured (non-missing) response in at least
# COVERAGE_THRESHOLD of the rows of response_df, most-covered first.
drug_response = response_df[drug_cols]
coverage = drug_response.notna().sum().sort_values(ascending=False)  # non-missing count per drug
total_samples = len(response_df)

COVERAGE_THRESHOLD = 0.90
# Round the cut-off UP. int() truncation would admit drugs just below the
# threshold (e.g. 900/1001 = 89.9% would pass a 90% cut-off).
min_samples = int(np.ceil(COVERAGE_THRESHOLD * total_samples))
selected_drugs = coverage[coverage >= min_samples].index.tolist()
if not selected_drugs:
    # Every downstream step needs at least one target column.
    raise ValueError(
        f"No drug reaches {COVERAGE_THRESHOLD:.0%} coverage ({min_samples} samples); "
        "lower COVERAGE_THRESHOLD"
    )

print(f"Coverage threshold: {COVERAGE_THRESHOLD*100:.0f}% ({min_samples} samples)")
print(f"Selected drugs: {len(selected_drugs)}")
print(f"\nTop 5: {selected_drugs[:5]}")

## Data Preparation

In [None]:
# Attach the selected drug responses to the feature table. The inner join
# keeps only rows whose index appears in both tables.
df = methexpr_df.join(response_df[selected_drugs], how='inner')

# Keep cells with a measured response for at least this fraction of selected drugs.
CELL_COVERAGE_THRESHOLD = 0.95
Y_all = df[selected_drugs]
cell_coverage = Y_all.notna().sum(axis=1) / len(selected_drugs)
df_filtered = df[cell_coverage >= CELL_COVERAGE_THRESHOLD].copy()

# Prepare X and Y
feature_cols = methylation_cols + expression_cols
X = df_filtered[feature_cols].dropna(axis=1)  # drop every feature column containing a NaN

# Record which responses were actually measured before imputation, so they
# can be told apart from imputed ones later.
Y_observed = df_filtered[selected_drugs].notna()
# NOTE(review): the per-drug medians are computed over ALL retained cells,
# before any train/test split, and imputed test entries are later scored as
# if they were measured. This leaks test information and distorts test
# metrics. Consider train-only medians per split, or masking with Y_observed.
Y = df_filtered[selected_drugs].fillna(df_filtered[selected_drugs].median())

print("Final dataset:")
print(f"  X: {X.shape}")
print(f"  Y: {Y.shape}")
print(f"  Dropped {len(feature_cols) - X.shape[1]} features with NaN")
print(f"  Imputed {int((~Y_observed).to_numpy().sum())} missing drug responses with per-drug medians")

## Evaluation Functions

In [None]:
def evaluate_multitask(model, X_train, Y_train, X_test, Y_test):
    """Score a fitted multi-output regressor drug by drug.

    Parameters
    ----------
    model : fitted estimator exposing ``predict``. Its output columns are
        assumed to follow the column order of ``Y_train`` / ``Y_test``.
    X_train, X_test : feature matrices passed straight to ``model.predict``.
    Y_train, Y_test : DataFrames of true responses, one column per drug.

    Returns
    -------
    dict
        ``train_r2``, ``train_mse``, ``test_r2`` and ``test_mse`` are per-drug
        metrics averaged over drugs. ``pct_positive`` is the percentage of
        drugs with test R2 > 0. ``per_drug_r2`` maps each drug to its test R2.
    """
    fitted = pd.DataFrame(model.predict(X_train), columns=Y_train.columns, index=Y_train.index)
    held_out = pd.DataFrame(model.predict(X_test), columns=Y_test.columns, index=Y_test.index)

    drugs = list(Y_train.columns)
    r2_tr = [r2_score(Y_train[d], fitted[d]) for d in drugs]
    mse_tr = [mean_squared_error(Y_train[d], fitted[d]) for d in drugs]
    r2_te = [r2_score(Y_test[d], held_out[d]) for d in drugs]
    mse_te = [mean_squared_error(Y_test[d], held_out[d]) for d in drugs]

    return {
        'train_r2': np.mean(r2_tr),
        'train_mse': np.mean(mse_tr),
        'test_r2': np.mean(r2_te),
        'test_mse': np.mean(mse_te),
        # Share of drugs where the model beats a constant test-mean predictor.
        'pct_positive': 100 * np.mean(np.asarray(r2_te) > 0),
        'per_drug_r2': dict(zip(Y_train.columns, r2_te)),
    }

def print_results(results, name):
    """Print a banner with ``name`` followed by the summary metrics in ``results``.

    ``results`` is a dict as returned by ``evaluate_multitask``.
    """
    rule = '=' * 50
    print(f"\n{rule}")
    print(f"{name}")
    print(rule)
    print(f"Train - R2: {results['train_r2']:.4f} | MSE: {results['train_mse']:.4f}")
    print(f"Test  - R2: {results['test_r2']:.4f} | MSE: {results['test_mse']:.4f}")
    print(f"Drugs with R2>0: {results['pct_positive']:.1f}%")

def create_benchmark():
    """Return an empty results table with one column per reported metric."""
    columns = ['Model', 'Train R2', 'Train MSE', 'Test R2', 'Test MSE', '% R2>0']
    return pd.DataFrame(columns=columns)

def add_to_benchmark(df, name, results):
    """Append one row of metrics for model ``name`` to ``df`` in place.

    The new row is labelled ``len(df)``, which assumes ``df`` uses a default
    0..n-1 index (as produced by ``create_benchmark``). Returns ``df`` itself.
    """
    metric_keys = ('train_r2', 'train_mse', 'test_r2', 'test_mse', 'pct_positive')
    row = [name] + [results[key] for key in metric_keys]
    df.loc[len(df)] = row
    return df

## Configuration

In [None]:
# N_COMPONENTS: principal components kept after scaling, in every split.
# SPLIT_THRESHOLD: fraction of histology/site categories (rarest first)
#                  held out as the test set in the group-based splits.
N_COMPONENTS, SPLIT_THRESHOLD = 50, 0.40

---
# Split 0: Random (Baseline)

In [None]:
# Baseline: seeded 80/20 random split of cells.
X_train_r, X_test_r, Y_train_r, Y_test_r = train_test_split(
    X, Y, test_size=0.2, random_state=42
)

# Scaler and PCA are fitted on the training cells only, then applied to test.
scaler_r = RobustScaler()
X_train_r_s, X_test_r_s = scaler_r.fit_transform(X_train_r), scaler_r.transform(X_test_r)

pca_r = PCA(n_components=N_COMPONENTS)
X_train_r_pca, X_test_r_pca = pca_r.fit_transform(X_train_r_s), pca_r.transform(X_test_r_s)

print(f"Random Split: Train {len(X_train_r)} | Test {len(X_test_r)}")
print(f"PCA variance: {100 * pca_r.explained_variance_ratio_.sum():.1f}%")

In [None]:
bench_r = create_benchmark()

# Ridge: linear model on the PCA features, all drugs as targets.
ridge_r = Ridge(alpha=1.0).fit(X_train_r_pca, Y_train_r)
res_ridge_r = evaluate_multitask(ridge_r, X_train_r_pca, Y_train_r, X_test_r_pca, Y_test_r)
print_results(res_ridge_r, 'Ridge (Random)')
bench_r = add_to_benchmark(bench_r, 'Ridge', res_ridge_r)

# Random Forest: a single multi-output forest over all drugs.
rf_r = RandomForestRegressor(
    n_estimators=100, max_depth=10, random_state=42, n_jobs=-1
).fit(X_train_r_pca, Y_train_r)
res_rf_r = evaluate_multitask(rf_r, X_train_r_pca, Y_train_r, X_test_r_pca, Y_test_r)
print_results(res_rf_r, 'Random Forest (Random)')
bench_r = add_to_benchmark(bench_r, 'RandomForest', res_rf_r)

# XGBoost: MultiOutputRegressor fits one XGBRegressor per drug column.
# NOTE(review): device='cuda' assumes a CUDA-capable XGBoost/GPU; confirm for local runs.
print(f"\nTraining XGBoost ({len(selected_drugs)} drugs)...")
xgb_r = MultiOutputRegressor(
    XGBRegressor(
        n_estimators=100,
        max_depth=5,
        learning_rate=0.1,
        random_state=42,
        tree_method='hist',
        device='cuda',
        verbosity=0,
    ),
    n_jobs=1,
).fit(X_train_r_pca, Y_train_r)
res_xgb_r = evaluate_multitask(xgb_r, X_train_r_pca, Y_train_r, X_test_r_pca, Y_test_r)
print_results(res_xgb_r, 'XGBoost (Random)')
bench_r = add_to_benchmark(bench_r, 'XGBoost', res_xgb_r)

In [None]:
# Summary table for the random (baseline) split.
print(f"\n{'=' * 60}")
print("RANDOM SPLIT (Baseline)")
print(f"Train: {len(X_train_r)} | Test: {len(X_test_r)}")
print('=' * 60)
display(bench_r)

---
# Split 1: Histology-Based (Unseen Cancer Subtypes)

In [None]:
# Histology split: hold out the least frequent histology categories, so the
# test set contains only cancer subtypes never seen during training.
histology = df_filtered['primary histology']
hist_counts = histology.value_counts(ascending=True)  # rarest first; NaN labels are not counted
split_idx = int(len(hist_counts) * SPLIT_THRESHOLD)
test_hist = hist_counts.index[:split_idx].tolist()
train_hist = hist_counts.index[split_idx:].tolist()

# Fail early with a clear message, instead of a cryptic scaler/PCA error on an empty split.
if not test_hist or not train_hist:
    raise ValueError(
        f"SPLIT_THRESHOLD={SPLIT_THRESHOLD} leaves an empty histology split "
        f"({len(hist_counts)} categories, {split_idx} held out)"
    )

# Compute membership once and reuse it for X and Y.
train_mask_h = histology.isin(train_hist)
test_mask_h = histology.isin(test_hist)
n_unlabelled_h = int((~(train_mask_h | test_mask_h)).sum())
if n_unlabelled_h:
    # Cells without a histology label fall into neither split; report instead of dropping silently.
    print(f"Note: {n_unlabelled_h} cells with missing 'primary histology' excluded from both splits")

X_train_h, X_test_h = X[train_mask_h], X[test_mask_h]
Y_train_h, Y_test_h = Y[train_mask_h], Y[test_mask_h]

# Scaler and PCA are fitted on the training cells only, then applied to the held-out cells.
scaler_h = RobustScaler()
X_train_h_s = scaler_h.fit_transform(X_train_h)
X_test_h_s = scaler_h.transform(X_test_h)

pca_h = PCA(n_components=N_COMPONENTS)
X_train_h_pca = pca_h.fit_transform(X_train_h_s)
X_test_h_pca = pca_h.transform(X_test_h_s)

print(f"Histology Split: Train {len(X_train_h)} | Test {len(X_test_h)}")
print(f"Test histologies ({len(test_hist)}): {test_hist[:5]}{'...' if len(test_hist) > 5 else ''}")

In [None]:
bench_h = create_benchmark()

# Ridge: linear model on the PCA features, all drugs as targets.
ridge_h = Ridge(alpha=1.0).fit(X_train_h_pca, Y_train_h)
res_ridge_h = evaluate_multitask(ridge_h, X_train_h_pca, Y_train_h, X_test_h_pca, Y_test_h)
print_results(res_ridge_h, 'Ridge (Histology)')
bench_h = add_to_benchmark(bench_h, 'Ridge', res_ridge_h)

# Random Forest: a single multi-output forest over all drugs.
rf_h = RandomForestRegressor(
    n_estimators=100, max_depth=10, random_state=42, n_jobs=-1
).fit(X_train_h_pca, Y_train_h)
res_rf_h = evaluate_multitask(rf_h, X_train_h_pca, Y_train_h, X_test_h_pca, Y_test_h)
print_results(res_rf_h, 'Random Forest (Histology)')
bench_h = add_to_benchmark(bench_h, 'RandomForest', res_rf_h)

# XGBoost: MultiOutputRegressor fits one XGBRegressor per drug column.
# NOTE(review): device='cuda' assumes a CUDA-capable XGBoost/GPU; confirm for local runs.
print(f"\nTraining XGBoost ({len(selected_drugs)} drugs)...")
xgb_h = MultiOutputRegressor(
    XGBRegressor(
        n_estimators=100,
        max_depth=5,
        learning_rate=0.1,
        random_state=42,
        tree_method='hist',
        device='cuda',
        verbosity=0,
    ),
    n_jobs=1,
).fit(X_train_h_pca, Y_train_h)
res_xgb_h = evaluate_multitask(xgb_h, X_train_h_pca, Y_train_h, X_test_h_pca, Y_test_h)
print_results(res_xgb_h, 'XGBoost (Histology)')
bench_h = add_to_benchmark(bench_h, 'XGBoost', res_xgb_h)

In [None]:
# Summary table for the histology (unseen subtype) split.
print(f"\n{'=' * 60}")
print("HISTOLOGY SPLIT (Unseen Cancer Subtypes)")
print(f"Train: {len(X_train_h)} | Test: {len(X_test_h)}")
print('=' * 60)
display(bench_h)

---
# Split 2: Site-Based (Unseen Tissue Origins)

In [None]:
# Site split: hold out the least frequent primary sites, so the test set
# contains only tissues of origin never seen during training.
site = df_filtered['primary site']
site_counts = site.value_counts(ascending=True)  # rarest first; NaN labels are not counted
site_idx = int(len(site_counts) * SPLIT_THRESHOLD)
test_site = site_counts.index[:site_idx].tolist()
train_site = site_counts.index[site_idx:].tolist()

# Fail early with a clear message, instead of a cryptic scaler/PCA error on an empty split.
if not test_site or not train_site:
    raise ValueError(
        f"SPLIT_THRESHOLD={SPLIT_THRESHOLD} leaves an empty site split "
        f"({len(site_counts)} sites, {site_idx} held out)"
    )

# Compute membership once and reuse it for X and Y.
train_mask_s = site.isin(train_site)
test_mask_s = site.isin(test_site)
n_unlabelled_s = int((~(train_mask_s | test_mask_s)).sum())
if n_unlabelled_s:
    # Cells without a site label fall into neither split; report instead of dropping silently.
    print(f"Note: {n_unlabelled_s} cells with missing 'primary site' excluded from both splits")

X_train_s, X_test_s = X[train_mask_s], X[test_mask_s]
Y_train_s, Y_test_s = Y[train_mask_s], Y[test_mask_s]

# Scaler and PCA are fitted on the training cells only, then applied to the held-out cells.
scaler_s = RobustScaler()
X_train_s_s = scaler_s.fit_transform(X_train_s)
X_test_s_s = scaler_s.transform(X_test_s)

pca_s = PCA(n_components=N_COMPONENTS)
X_train_s_pca = pca_s.fit_transform(X_train_s_s)
X_test_s_pca = pca_s.transform(X_test_s_s)

print(f"Site Split: Train {len(X_train_s)} | Test {len(X_test_s)}")
print(f"Test sites: {test_site}")

In [None]:
bench_s = create_benchmark()

# Ridge: linear model on the PCA features, all drugs as targets.
ridge_s = Ridge(alpha=1.0).fit(X_train_s_pca, Y_train_s)
res_ridge_s = evaluate_multitask(ridge_s, X_train_s_pca, Y_train_s, X_test_s_pca, Y_test_s)
print_results(res_ridge_s, 'Ridge (Site)')
bench_s = add_to_benchmark(bench_s, 'Ridge', res_ridge_s)

# Random Forest: a single multi-output forest over all drugs.
rf_s = RandomForestRegressor(
    n_estimators=100, max_depth=10, random_state=42, n_jobs=-1
).fit(X_train_s_pca, Y_train_s)
res_rf_s = evaluate_multitask(rf_s, X_train_s_pca, Y_train_s, X_test_s_pca, Y_test_s)
print_results(res_rf_s, 'Random Forest (Site)')
bench_s = add_to_benchmark(bench_s, 'RandomForest', res_rf_s)

# XGBoost: MultiOutputRegressor fits one XGBRegressor per drug column.
# NOTE(review): device='cuda' assumes a CUDA-capable XGBoost/GPU; confirm for local runs.
print(f"\nTraining XGBoost ({len(selected_drugs)} drugs)...")
xgb_s = MultiOutputRegressor(
    XGBRegressor(
        n_estimators=100,
        max_depth=5,
        learning_rate=0.1,
        random_state=42,
        tree_method='hist',
        device='cuda',
        verbosity=0,
    ),
    n_jobs=1,
).fit(X_train_s_pca, Y_train_s)
res_xgb_s = evaluate_multitask(xgb_s, X_train_s_pca, Y_train_s, X_test_s_pca, Y_test_s)
print_results(res_xgb_s, 'XGBoost (Site)')
bench_s = add_to_benchmark(bench_s, 'XGBoost', res_xgb_s)

In [None]:
# Summary table for the site (unseen tissue) split.
print(f"\n{'=' * 60}")
print("SITE SPLIT (Unseen Tissue Origins)")
print(f"Train: {len(X_train_s)} | Test: {len(X_test_s)}")
print('=' * 60)
display(bench_s)

---
# Final Comparison

In [None]:
# Recap: dataset size, then the three benchmark tables one after another.
print('=' * 70)
print("MULTI-TASK DRUG RESPONSE: FINAL COMPARISON")
print('=' * 70)
print(f"Dataset: {len(X)} cells x {len(selected_drugs)} drugs")
print(f"Features: {X.shape[1]} -> {N_COMPONENTS} PCA")

print(f"\n{'-' * 70}")
print(f"RANDOM (Baseline) | Train: {len(X_train_r)} | Test: {len(X_test_r)}")
print('-' * 70)
display(bench_r)

print(f"\n{'-' * 70}")
print(f"HISTOLOGY (Hard) | Train: {len(X_train_h)} | Test: {len(X_test_h)}")
print('-' * 70)
display(bench_h)

print(f"\n{'-' * 70}")
print(f"SITE (Hard) | Train: {len(X_train_s)} | Test: {len(X_test_s)}")
print('-' * 70)
display(bench_s)

In [None]:
# Plain-text guide to reading the benchmark tables. R2 is computed per drug
# against that drug's own test-set mean (sklearn r2_score), and the
# group-based splits hold out the rarest categories.
print("\n" + "=" * 70)
print("INTERPRETATION")
print("=" * 70)
print("""
1. RANDOM (Baseline): test cells resemble training cells -> higher R2 expected
2. HISTOLOGY (Hard): held-out (rarest) cancer subtypes -> tests subtype generalization
3. SITE (Hard): held-out (rarest) tissues of origin -> tests tissue generalization

Key metric: '% R2>0' = percentage of drugs whose test R2 is positive,
i.e. predicted better than a constant equal to that drug's test-set mean
Negative R2 = model worse than predicting the test-set mean for every sample
""")

## Per-Drug Analysis

In [None]:
# Per-drug test R2 of the XGBoost models under each split strategy.
rand_r2 = pd.Series(res_xgb_r['per_drug_r2'])
hist_r2 = pd.Series(res_xgb_h['per_drug_r2'])
site_r2 = pd.Series(res_xgb_s['per_drug_r2'])

# One set of bin edges for all three overlaid histograms. Separate bins=30
# calls would give each series its own edges and make the bars incomparable.
# NaN R2 values (possible for degenerate test sets) are excluded from plotting.
all_r2 = pd.concat([rand_r2, hist_r2, site_r2]).dropna()
r2_bins = np.histogram_bin_edges(all_r2, bins=30)
# Identity line spans the observed R2 range instead of a fixed [-0.5, 0.5].
r2_span = [all_r2.min(), all_r2.max()] if len(all_r2) else [-0.5, 0.5]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(rand_r2.dropna(), bins=r2_bins, alpha=0.7, label='Random', color='green')
axes[0].hist(hist_r2.dropna(), bins=r2_bins, alpha=0.7, label='Histology', color='blue')
axes[0].hist(site_r2.dropna(), bins=r2_bins, alpha=0.7, label='Site', color='orange')
axes[0].axvline(x=0, color='red', linestyle='--')
axes[0].set_xlabel('Test R2')
axes[0].set_title('Per-Drug R2 Distribution')
axes[0].legend()

axes[1].scatter(rand_r2, hist_r2, alpha=0.5, c='blue')
axes[1].axhline(0, color='red', linestyle='--', alpha=0.5)
axes[1].axvline(0, color='red', linestyle='--', alpha=0.5)
axes[1].plot(r2_span, r2_span, 'k--', alpha=0.3)
axes[1].set_xlabel('Random R2')
axes[1].set_ylabel('Histology R2')
axes[1].set_title('Random vs Histology')

axes[2].scatter(rand_r2, site_r2, alpha=0.5, c='orange')
axes[2].axhline(0, color='red', linestyle='--', alpha=0.5)
axes[2].axvline(0, color='red', linestyle='--', alpha=0.5)
axes[2].plot(r2_span, r2_span, 'k--', alpha=0.3)
axes[2].set_xlabel('Random R2')
axes[2].set_ylabel('Site R2')
axes[2].set_title('Random vs Site')

plt.tight_layout()
plt.show()

# Mean per-drug R2 loss when moving from the random split to each group split.
print(f"\nR2 drop from Random baseline:")
print(f"  -> Histology: {(rand_r2 - hist_r2).mean():.4f}")
print(f"  -> Site: {(rand_r2 - site_r2).mean():.4f}")
print(f"\nDrugs with R2>0:")
print(f"  Random: {(rand_r2>0).sum()}/{len(rand_r2)} ({100*(rand_r2>0).mean():.1f}%)")
print(f"  Histology: {(hist_r2>0).sum()}/{len(hist_r2)} ({100*(hist_r2>0).mean():.1f}%)")
print(f"  Site: {(site_r2>0).sum()}/{len(site_r2)} ({100*(site_r2>0).mean():.1f}%)")