# Cross-Patient CNV and Cis-Effect Analysis

This notebook analyzes CNV patterns and cis-effects across all 10 patients to identify:
1. **Recurrent CNV patterns** - Which chromosomes show consistent amplification/deletion?
2. **Conserved cis-effects** - Which chromosomes consistently show dosage effects?
3. **Core dosage-sensitive genes** - Genes that are dosage-sensitive across multiple patients
4. **Patient heterogeneity** - How variable are CNV profiles between patients?

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from collections import defaultdict
import warnings
# NOTE(review): this silences every warning for the whole notebook, including
# deprecation notices from the plotting/analysis libraries; consider narrowing
# it to specific warning categories.
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Paths
CNV_DIR = '../data/cnv_output'
DE_DIR = '../data/de_results'

# Figures and summary tables are written into DE_DIR, so make sure it exists
# before the first plt.savefig / to_csv call.
os.makedirs(DE_DIR, exist_ok=True)

# Get all patients: only sub-directories of CNV_DIR whose name starts with 'P'
# count as patients, so stray files such as 'P_summary.csv' are ignored.
PATIENTS = sorted(
    entry for entry in os.listdir(CNV_DIR)
    if entry.startswith('P') and os.path.isdir(os.path.join(CNV_DIR, entry))
)
print(f"Patients: {PATIENTS}")
print(f"Total: {len(PATIENTS)} patients")

## 1. Load All Patient Data

In [None]:
# Load CNV data for all patients.
# patient_data maps patient id -> AnnData; patient_stats collects one summary
# row per successfully loaded patient.
patient_data = {}
patient_stats = []

for patient in PATIENTS:
    cnv_path = os.path.join(CNV_DIR, patient, f'{patient}_cnv.h5ad')
    if not os.path.exists(cnv_path):
        # Report the skip explicitly so a missing patient is not a surprise later.
        print(f"{patient}: no CNV file at {cnv_path}, skipping")
        continue

    adata = sc.read_h5ad(cnv_path)
    patient_data[patient] = adata

    # Collect stats
    labels = adata.obs['cancer_vs_normal']
    cnv_score = adata.obs['cnv_score']
    n_cancer = int((labels == 'Cancer').sum())
    n_normal = int((labels == 'Normal').sum())

    patient_stats.append({
        'patient': patient,
        'n_cells': adata.n_obs,
        'n_cancer': n_cancer,
        'n_normal': n_normal,
        # Guard against an empty AnnData (no cells at all).
        'cancer_fraction': n_cancer / adata.n_obs if adata.n_obs else np.nan,
        'n_cnv_clusters': adata.obs['cnv_leiden'].nunique(),
        'mean_cnv_score': cnv_score.mean(),
        'std_cnv_score': cnv_score.std(),
        'max_cnv_score': cnv_score.max(),
    })
    print(f"{patient}: {adata.n_obs:,} cells ({n_cancer:,} cancer, {n_normal:,} normal)")

if not patient_stats:
    # Fail with an actionable message instead of a KeyError on an empty frame.
    raise FileNotFoundError(f"No '<patient>_cnv.h5ad' files found under {CNV_DIR}")

stats_df = pd.DataFrame(patient_stats)
print(f"\nTotal cells across all patients: {stats_df['n_cells'].sum():,}")

In [None]:
# Render the per-patient table: cancer fraction as a percentage and the three
# CNV-score columns with four decimals.
score_columns = ['mean_cnv_score', 'std_cnv_score', 'max_cnv_score']
column_formats = {'cancer_fraction': '{:.1%}'}
column_formats.update({column: '{:.4f}' for column in score_columns})
stats_df.style.format(column_formats)

## 2. CNV Score Distribution Across Patients

In [None]:
# Four-panel overview of CNV scores. Only patients that were actually loaded
# (keys of patient_data / rows of stats_df) are plotted, so a patient without a
# CNV file cannot raise a KeyError or desynchronise bar positions and labels.

# Long-format table with one row per cancer/normal cell, built with vectorised
# pandas operations instead of appending one dict per cell.
score_frames = []
for patient, adata in patient_data.items():
    obs = adata.obs
    labelled = obs['cancer_vs_normal'].isin(['Cancer', 'Normal'])
    score_frames.append(pd.DataFrame({
        'Patient': patient,
        # Cast to str so leftover categorical levels do not leak into the legend.
        'Type': obs.loc[labelled, 'cancer_vs_normal'].astype(str).to_numpy(),
        'CNV Score': obs.loc[labelled, 'cnv_score'].to_numpy(),
    }))
plot_df = pd.concat(score_frames, ignore_index=True)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: CNV score distribution by patient (all cells of each patient)
ax = axes[0, 0]
for patient, adata in patient_data.items():
    ax.hist(adata.obs['cnv_score'], bins=50, alpha=0.5, label=patient, density=True)
ax.set_xlabel('CNV Score')
ax.set_ylabel('Density')
ax.set_title('CNV Score Distribution by Patient')
ax.legend(fontsize=8, ncol=2)

# Plot 2: Cancer vs Normal CNV scores (pooled)
ax = axes[0, 1]
cancer_scores = plot_df.loc[plot_df['Type'] == 'Cancer', 'CNV Score']
normal_scores = plot_df.loc[plot_df['Type'] == 'Normal', 'CNV Score']

ax.hist(normal_scores, bins=50, alpha=0.6, label=f'Normal (n={len(normal_scores):,})', density=True)
ax.hist(cancer_scores, bins=50, alpha=0.6, label=f'Cancer (n={len(cancer_scores):,})', density=True)
ax.set_xlabel('CNV Score')
ax.set_ylabel('Density')
ax.set_title('CNV Score: Cancer vs Normal (All Patients)')
ax.legend()

# Statistical test (mannwhitneyu raises on an empty sample, so guard it)
if len(cancer_scores) > 0 and len(normal_scores) > 0:
    stat, pval = stats.mannwhitneyu(cancer_scores, normal_scores)
    ax.text(0.95, 0.95, f'Mann-Whitney p={pval:.2e}', transform=ax.transAxes, ha='right', va='top')

# Plot 3: Box plot of CNV scores by patient and cell type
ax = axes[1, 0]
sns.boxplot(data=plot_df, x='Patient', y='CNV Score', hue='Type',
            order=list(patient_data), hue_order=['Cancer', 'Normal'], ax=ax)
# Rotate via tick_params; re-setting labels from get_xticklabels() triggers a
# FixedLocator warning and can blank the labels.
ax.tick_params(axis='x', labelrotation=45)
ax.set_title('CNV Score by Patient and Cell Type')
ax.legend(loc='upper right')

# Plot 4: Summary statistics (positions and labels both come from stats_df)
ax = axes[1, 1]
x = range(len(stats_df))
ax.bar(x, stats_df['mean_cnv_score'], yerr=stats_df['std_cnv_score'], capsize=3, alpha=0.7)
ax.set_xticks(x)
ax.set_xticklabels(stats_df['patient'], rotation=45)
ax.set_ylabel('Mean CNV Score')
# Escape sequence keeps the plus-minus sign intact regardless of file encoding.
ax.set_title('Mean CNV Score \u00b1 Std by Patient')

plt.tight_layout()
plt.savefig(os.path.join(DE_DIR, 'cross_patient_cnv_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

## 3. Run DE Analysis for All Patients

In [None]:
# Run DE analysis for patients that don't have results yet
import subprocess
import sys

# Path of the per-patient DE script, relative to the notebook's working directory.
DE_SCRIPT = '../08_differential_expression.py'

for patient in PATIENTS:
    de_path = os.path.join(DE_DIR, patient, 'chromosome_cis_effects.csv')
    if os.path.exists(de_path):
        print(f"{patient}: DE results already exist")
        continue

    print(f"Running DE analysis for {patient}...")
    # sys.executable runs the script with the same interpreter (and therefore
    # the same installed packages) as this kernel; a bare 'python' may resolve
    # to a different environment or not exist at all.
    result = subprocess.run(
        [sys.executable, DE_SCRIPT, '--patient', patient],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        # The exception message sits at the end of a traceback, so show the
        # tail of stderr rather than its first characters.
        print(f"  Error (exit code {result.returncode}): {result.stderr.strip()[-500:]}")
    else:
        print("  Done!")

## 4. Cross-Patient Chromosome Cis-Effects

In [None]:
# Load chromosome-level cis-effects for all patients.
# chrom_effects maps patient id -> DataFrame read from that patient's
# chromosome_cis_effects.csv, with a 'patient' column added.
chrom_effects = {}

for patient in PATIENTS:
    chrom_path = os.path.join(DE_DIR, patient, 'chromosome_cis_effects.csv')
    if not os.path.exists(chrom_path):
        # Report skipped patients, as the gene-level loader does.
        print(f"{patient}: No chromosome cis-effects file")
        continue

    df = pd.read_csv(chrom_path)
    df['patient'] = patient
    chrom_effects[patient] = df
    print(f"{patient}: {len(df)} chromosomes, {df['strong_cis_effect'].sum()} with strong cis-effect")

if not chrom_effects:
    # pd.concat on an empty collection fails with an unhelpful message; every
    # later cell needs all_chrom, so stop here with a clear explanation.
    raise FileNotFoundError(
        f"No 'chromosome_cis_effects.csv' files found under {DE_DIR}; run the DE analysis first"
    )

# Combine all
all_chrom = pd.concat(chrom_effects.values(), ignore_index=True)
print(f"\nTotal chromosome-patient combinations: {len(all_chrom)}")

In [None]:
# Create heatmap of cis-effects across patients
def chrom_sort_key(x):
    """Return an integer sort key that orders chromosome labels karyotypically.

    The label is converted to ``str`` and every ``'chr'`` substring is removed.

    Args:
        x: Chromosome label, e.g. ``'chr1'``, ``'7'``, ``7``, ``'X'``, ``'chrY'``.

    Returns:
        int: The chromosome number for integer labels, 23 for ``'X'``,
        24 for ``'Y'``, and 99 for any other label so that it sorts last.
    """
    x = str(x).replace('chr', '')
    if x == 'X':
        return 23
    if x == 'Y':
        return 24
    try:
        return int(x)
    except ValueError:
        # Only a non-integer label is expected here; a bare except would also
        # swallow KeyboardInterrupt and hide unrelated errors.
        return 99

# Chromosome x patient matrices, one per metric
pivot_cis, pivot_cis_trans = (
    all_chrom.pivot(index='chromosome', columns='patient', values=metric)
    for metric in ('cis_correlation', 'cis_minus_trans')
)

# Reorder the rows of both matrices using chrom_sort_key
sorted_chroms = sorted(pivot_cis.index, key=chrom_sort_key)
pivot_cis, pivot_cis_trans = pivot_cis.loc[sorted_chroms], pivot_cis_trans.loc[sorted_chroms]

fig, axes = plt.subplots(1, 2, figsize=(16, 10))

# One entry per panel: (matrix, title, heatmap options specific to that panel)
heatmap_panels = [
    # Left: cis-correlation, diverging colormap centred on zero
    (
        pivot_cis,
        'Cis-Correlation (Chr Expression vs Chr CNV)\nby Chromosome and Patient',
        dict(cmap='RdBu_r', center=0, cbar_kws={'label': 'Cis-Correlation'}),
    ),
    # Right: cis-specificity (cis - trans)
    (
        pivot_cis_trans,
        'Cis-Specificity (Cis - Trans Correlation)\nPositive = True Dosage Effect',
        dict(cmap='Greens', cbar_kws={'label': 'Cis - Trans'}),
    ),
]

for panel_ax, (matrix, panel_title, panel_kwargs) in zip(axes, heatmap_panels):
    sns.heatmap(matrix, annot=True, fmt='.2f', ax=panel_ax, **panel_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Patient')
    panel_ax.set_ylabel('Chromosome')

plt.tight_layout()
plt.savefig('../data/de_results/cross_patient_cis_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Per-chromosome summary across patients: spread of the cis-correlation, mean
# and std of cis-specificity, number of strong cis-effects and number of rows.
# Named aggregation gives each output column its final name directly.
chrom_summary = (
    all_chrom
    .groupby('chromosome')
    .agg(
        cis_mean=('cis_correlation', 'mean'),
        cis_std=('cis_correlation', 'std'),
        cis_min=('cis_correlation', 'min'),
        cis_max=('cis_correlation', 'max'),
        cis_trans_mean=('cis_minus_trans', 'mean'),
        cis_trans_std=('cis_minus_trans', 'std'),
        n_strong=('strong_cis_effect', 'sum'),
        n_patients=('patient', 'count'),
    )
    .round(3)
)

# Fraction of rows for this chromosome that carry a strong cis-effect
chrom_summary['consistency'] = chrom_summary['n_strong'] / chrom_summary['n_patients']
chrom_summary = chrom_summary.sort_values('cis_trans_mean', ascending=False)

print("Chromosomes ranked by mean cis-specificity (cis - trans):")
print("="*80)
summary_styler = chrom_summary.style.format({'consistency': '{:.0%}'})
summary_styler.background_gradient(subset=['cis_trans_mean', 'consistency'], cmap='Greens')

In [None]:
# Bar plot of chromosome consistency: one bar per chromosome showing mean
# cis-specificity with a std error bar, coloured by the fraction of patients
# with a strong cis-effect and annotated with that fraction.
fig, ax = plt.subplots(figsize=(12, 6))

sorted_summary = chrom_summary.sort_index(key=lambda x: x.map(chrom_sort_key))
colors = ['green' if c > 0.7 else 'orange' if c > 0.5 else 'gray'
          for c in sorted_summary['consistency']]

# The std is NaN for a chromosome that appears in only one patient; treat it as
# zero spread so the error bar and the label position stay finite.
cis_trans_spread = sorted_summary['cis_trans_std'].fillna(0)
bar_positions = range(len(sorted_summary))

bars = ax.bar(bar_positions, sorted_summary['cis_trans_mean'],
              yerr=cis_trans_spread, capsize=3, color=colors, alpha=0.7)

ax.set_xticks(bar_positions)
ax.set_xticklabels(sorted_summary.index, rotation=45)
ax.set_xlabel('Chromosome')
ax.set_ylabel('Mean Cis-Specificity (Cis - Trans)')
ax.set_title('Cross-Patient Cis-Effect Consistency by Chromosome\n(Green = >70% patients show strong effect)')
ax.axhline(0, color='black', linestyle='-', alpha=0.3)

# Add consistency labels just above the top of each error bar
label_values = zip(sorted_summary['cis_trans_mean'], cis_trans_spread, sorted_summary['consistency'])
for i, (mean_value, spread_value, consistency) in enumerate(label_values):
    ax.text(i, mean_value + spread_value + 0.02,
            f"{consistency:.0%}", ha='center', fontsize=8)

plt.tight_layout()
plt.savefig(os.path.join(DE_DIR, 'cross_patient_cis_consistency.png'), dpi=150, bbox_inches='tight')
plt.show()

## 5. Cross-Patient Dosage-Sensitive Genes

In [None]:
# Read each patient's dosage_sensitive_genes.csv (when present) into
# gene_effects, keyed by patient id and tagged with a 'patient' column.
gene_effects = {}

for patient in PATIENTS:
    gene_path = os.path.join(DE_DIR, patient, 'dosage_sensitive_genes.csv')
    if not os.path.exists(gene_path):
        print(f"{patient}: No dosage-sensitive genes file")
        continue
    df = pd.read_csv(gene_path)
    df['patient'] = patient
    gene_effects[patient] = df
    print(f"{patient}: {len(df)} dosage-sensitive genes")

# all_genes is only defined when at least one patient file was found
if len(gene_effects) > 0:
    all_genes = pd.concat(gene_effects.values(), ignore_index=True)
    n_gene_patient_rows = len(all_genes)
    n_unique_genes = all_genes['gene'].nunique()
    print(f"\nTotal gene-patient combinations: {n_gene_patient_rows}")
    print(f"Unique genes: {n_unique_genes}")

In [None]:
# Find genes that are dosage-sensitive in multiple patients.
# A gene is "recurrent" when it is reported for at least this many patients.
MIN_RECURRENT_PATIENTS = 3

if gene_effects:
    gene_patient_counts = (
        all_genes
        .groupby('gene')
        .agg(
            # nunique counts distinct patients; 'count' would count rows and
            # over-report a gene listed more than once in one patient's file.
            n_patients=('patient', 'nunique'),
            mean_cis_corr=('corr_cis', 'mean'),
            mean_cis_trans=('cis_minus_trans', 'mean'),
            chromosome=('chromosome', 'first'),
        )
        # Break ties in n_patients by mean cis-specificity so the ranking (and
        # therefore every later "top N" selection) is reproducible.
        .sort_values(['n_patients', 'mean_cis_trans'], ascending=[False, False])
    )

    print(f"\nGenes dosage-sensitive in {MIN_RECURRENT_PATIENTS}+ patients:")
    print("="*80)
    recurrent = gene_patient_counts[gene_patient_counts['n_patients'] >= MIN_RECURRENT_PATIENTS]
    print(f"Found {len(recurrent)} genes")
    print()
    print(recurrent.head(30).to_string())

In [None]:
# Two-panel figure for the recurrent genes; skipped when there are none
if gene_effects and len(recurrent) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: one horizontal bar per gene for the 30 most recurrent genes,
    # coloured by chromosome (Set2 entry = chromosome sort key modulo 8)
    ax = axes[0]
    top30 = recurrent.head(30)
    gene_positions = range(len(top30))
    colors = []
    for chrom_label in top30['chromosome']:
        colors.append(plt.cm.Set2(chrom_sort_key(chrom_label) % 8))
    ax.barh(gene_positions, top30['n_patients'], color=colors)
    ax.set_yticks(gene_positions)
    ax.set_yticklabels(top30.index)
    ax.set(xlabel='Number of Patients', title='Top 30 Recurrent Dosage-Sensitive Genes')
    ax.invert_yaxis()

    # Right panel: number of recurrent genes per chromosome, ordered with
    # chrom_sort_key
    ax = axes[1]
    chrom_counts = recurrent['chromosome'].value_counts()
    sorted_chroms = sorted(chrom_counts.index, key=chrom_sort_key)
    chrom_counts = chrom_counts[sorted_chroms]
    chrom_positions = range(len(chrom_counts))
    ax.bar(chrom_positions, chrom_counts.values)
    ax.set_xticks(chrom_positions)
    ax.set_xticklabels(chrom_counts.index, rotation=45)
    ax.set(
        xlabel='Chromosome',
        ylabel='Number of Recurrent Genes',
        title='Chromosome Distribution of\nRecurrent Dosage-Sensitive Genes',
    )

    plt.tight_layout()
    plt.savefig('../data/de_results/cross_patient_dosage_genes.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Heatmap of top recurrent genes across patients
if gene_effects and len(recurrent) > 0:
    # Get top 20 most recurrent genes
    top_genes = recurrent.head(20).index.tolist()

    # Create matrix: genes x patients with cis-correlation values.
    # A single pivot replaces the nested gene/patient loop, which rescanned each
    # patient's gene column for every gene and filled an object-dtype frame.
    top_gene_rows = pd.concat(gene_effects.values(), ignore_index=True)
    top_gene_rows = top_gene_rows[top_gene_rows['gene'].isin(top_genes)]
    gene_matrix = (
        top_gene_rows
        # Keep the first row per gene/patient, as the loop's .values[0] did.
        .drop_duplicates(subset=['gene', 'patient'])
        .pivot(index='gene', columns='patient', values='corr_cis')
        # Same row/column order as before; gene/patient pairs without a value
        # become NaN and are masked in the heatmap.
        .reindex(index=top_genes, columns=PATIENTS)
        .rename_axis(index=None, columns=None)
        .astype(float)
    )

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(gene_matrix, annot=True, fmt='.2f', cmap='RdYlGn', center=0.3,
                mask=gene_matrix.isna(), ax=ax, cbar_kws={'label': 'Cis-Correlation'})
    ax.set_title('Top Recurrent Dosage-Sensitive Genes\nCis-Correlation Across Patients')
    ax.set_xlabel('Patient')
    ax.set_ylabel('Gene')

    plt.tight_layout()
    plt.savefig(os.path.join(DE_DIR, 'cross_patient_gene_heatmap.png'), dpi=150, bbox_inches='tight')
    plt.show()

## 6. CNV Profile Similarity Between Patients

In [None]:
# Compute per-chromosome mean CNV for each patient (cancer cells only).
# Iterating over patient_data (rather than PATIENTS) covers exactly the
# patients whose CNV file was loaded.
patient_chrom_cnv = {}

for patient, adata in patient_data.items():
    # Get cancer cells only
    cancer_mask = (adata.obs['cancer_vs_normal'] == 'Cancer').to_numpy()
    if not cancer_mask.any():
        print(f"{patient}: no cancer cells, skipping")
        continue

    # Get chromosome mapping: chromosome name -> index of its first CNV window
    chr_pos = adata.uns.get('cnv', {}).get('chr_pos', {})
    if not chr_pos:
        print(f"{patient}: no chromosome positions in uns['cnv'], skipping")
        continue

    cnv_matrix = adata.obsm['X_cnv'][cancer_mask]
    # Mean of every window over the cancer cells. All windows share the same
    # number of cells, so averaging these window means over a chromosome equals
    # the mean over all of its entries, without densifying a sparse matrix.
    window_means = np.asarray(cnv_matrix.mean(axis=0)).ravel()
    n_windows = window_means.shape[0]

    # Chromosomes ordered by start index; each runs up to the next one's start.
    chrom_starts = sorted(chr_pos.items(), key=lambda item: item[1])

    chrom_cnv = {}
    for i, (chrom_name, start_idx) in enumerate(chrom_starts):
        end_idx = chrom_starts[i + 1][1] if i + 1 < len(chrom_starts) else n_windows
        chrom = str(chrom_name).replace('chr', '')
        chrom_cnv[chrom] = window_means[start_idx:end_idx].mean()

    patient_chrom_cnv[patient] = chrom_cnv

# Create DataFrame (rows = patients, columns = chromosomes)
cnv_profile_df = pd.DataFrame(patient_chrom_cnv).T
# Sort columns by chromosome order
sorted_cols = sorted(cnv_profile_df.columns, key=chrom_sort_key)
cnv_profile_df = cnv_profile_df[sorted_cols]

print("Mean CNV by chromosome (cancer cells):")
cnv_profile_df

In [None]:
# Side-by-side heatmaps: per-chromosome mean CNV for each patient, and the
# patient-by-patient correlation of those profiles
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
profile_ax, similarity_ax = axes

# Left: CNV profile heatmap
sns.heatmap(
    cnv_profile_df,
    cmap='RdBu_r',
    center=0,
    annot=True,
    fmt='.3f',
    ax=profile_ax,
    cbar_kws={'label': 'Mean CNV'},
)
profile_ax.set(
    title='Mean CNV Profile by Chromosome\n(Cancer Cells Only)',
    xlabel='Chromosome',
    ylabel='Patient',
)

# Right: patient similarity; transposing puts patients in columns so that
# .corr() correlates patients with each other
patient_corr = cnv_profile_df.transpose().corr()
sns.heatmap(
    patient_corr,
    cmap='coolwarm',
    center=0,
    annot=True,
    fmt='.2f',
    ax=similarity_ax,
    cbar_kws={'label': 'Correlation'},
)
similarity_ax.set_title('Patient CNV Profile Similarity\n(Pearson Correlation)')

plt.tight_layout()
plt.savefig('../data/de_results/cross_patient_cnv_profiles.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Cluster patients by CNV profile
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist

# linkage() rejects non-finite values, so cluster only on chromosomes that have
# a value for every patient (a chromosome missing from one patient's mapping
# shows up as NaN in cnv_profile_df).
cluster_profiles = cnv_profile_df.dropna(axis=1)

if cluster_profiles.shape[0] < 2 or cluster_profiles.shape[1] == 0:
    # Hierarchical clustering needs at least two observations with shared features.
    print("Not enough complete CNV profiles to cluster patients "
          "(need at least 2 patients sharing at least 1 chromosome)")
else:
    fig, ax = plt.subplots(figsize=(10, 6))

    # Compute linkage
    Z = linkage(cluster_profiles.values, method='ward')

    # Plot dendrogram
    dendrogram(Z, labels=cluster_profiles.index.tolist(), ax=ax, leaf_rotation=45)
    ax.set_title('Patient Clustering by CNV Profile\n(Ward linkage)')
    ax.set_ylabel('Distance')

    plt.tight_layout()
    plt.savefig(os.path.join(DE_DIR, 'cross_patient_clustering.png'), dpi=150, bbox_inches='tight')
    plt.show()

## 7. Summary: Key Findings

In [None]:
# Text summary of the notebook's main tables.
print("="*80)
print("CROSS-PATIENT ANALYSIS SUMMARY")
print("="*80)

print("\n1. DATASET OVERVIEW")
# Count the patients that were actually loaded (rows of stats_df) so this line
# agrees with the cell totals below, which are also computed from stats_df.
print(f"   - Total patients: {len(stats_df)}")
print(f"   - Total cells: {stats_df['n_cells'].sum():,}")
print(f"   - Cancer cells: {stats_df['n_cancer'].sum():,}")
print(f"   - Normal cells: {stats_df['n_normal'].sum():,}")

print("\n2. CHROMOSOMES WITH STRONGEST CIS-EFFECTS (consistent across patients)")
# chrom_summary is sorted by mean cis-specificity, so head(5) is the top five.
top_chroms = chrom_summary.head(5)
for chrom, row in top_chroms.iterrows():
    print(f"   - {chrom}: cis-specificity={row['cis_trans_mean']:.3f}, consistency={row['consistency']:.0%}")

if gene_effects and len(recurrent) > 0:
    print("\n3. RECURRENT DOSAGE-SENSITIVE GENES")
    print(f"   - Genes in 3+ patients: {len(recurrent)}")
    top5 = recurrent.head(5)
    for gene, row in top5.iterrows():
        print(f"   - {gene} ({row['chromosome']}): {int(row['n_patients'])} patients, mean cis-corr={row['mean_cis_corr']:.3f}")

# NOTE(review): the statements below are hardcoded text, not derived from
# chrom_summary or recurrent; confirm they still match the tables printed above
# whenever the input data changes.
print("\n4. KEY BIOLOGICAL INSIGHTS")
print("   - chr6 (HLA genes) shows strongest and most consistent cis-effects")
print("   - chr1 (S100 family) shows second strongest cis-effects")
print("   - MHC Class II genes are dosage-sensitive across most patients")
print("   - This suggests CNV impacts immune presentation in lung cancer")

In [None]:
# Write the summary tables to CSV. Destinations are collected first (in write
# order) and then written in one loop; the recurrent-gene table is included
# only when recurrent genes were found.
summary_outputs = [('../data/de_results/cross_patient_chromosome_summary.csv', chrom_summary)]
if gene_effects and len(recurrent) > 0:
    summary_outputs.append(('../data/de_results/cross_patient_recurrent_genes.csv', recurrent))
summary_outputs.append(('../data/de_results/cross_patient_cnv_profiles.csv', cnv_profile_df))

for csv_path, table in summary_outputs:
    table.to_csv(csv_path)

print("Saved summary files to data/de_results/")