In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr, pearsonr
from matplotlib_venn import venn2
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

# ============================================================================
# 1. LOAD RESULTS FROM BOTH METHODS
# ============================================================================

print("Loading results from both methods...")

# Load Scanpy results - single file with all significant DEGs by cell type
scanpy_all = pd.read_csv('left_vs_right_DEGs_by_celltype.csv')

# Load scVI results - individual cell type files (already filtered for significant DEGs)
scvi_results = {}
cell_types = scanpy_all['Cell_Type'].unique()

for cell_type in cell_types:
    filename = f"DEG_{cell_type.replace(' ', '_')}_left_vs_right.csv"
    try:
        df = pd.read_csv(filename, index_col=0)
        df['Cell_Type'] = cell_type
        df['gene'] = df.index
        scvi_results[cell_type] = df
    except FileNotFoundError:
        print(f"Warning: {filename} not found")

# Combine all scVI results
scvi_all = pd.concat(scvi_results.values(), ignore_index=True)

print(f"\nScanpy significant DEGs: {len(scanpy_all)}")
print(f"scVI significant DEGs: {len(scvi_all)}")
print(f"\nScanpy cell types: {scanpy_all['Cell_Type'].unique()}")
print(f"scVI cell types: {scvi_all['Cell_Type'].unique()}")

# ============================================================================
# 2. STANDARDIZE COLUMN NAMES FOR COMPARISON
# ============================================================================

# Rename columns to have consistent naming
scanpy_comparison = scanpy_all.rename(columns={
    'names': 'gene',
    'logfoldchanges': 'log2FC',
    'pvals_adj': 'padj'
}).copy()

scvi_comparison = scvi_all.rename(columns={
    'lfc_mean': 'log2FC',
    'proba_de': 'posterior_prob'
}).copy()

# Add method identifier
scanpy_comparison['method'] = 'Scanpy'
scvi_comparison['method'] = 'scVI'

print(f"\nScanpy genes: {len(scanpy_comparison)}")
print(f"scVI genes: {len(scvi_comparison)}")

# ============================================================================
# 3. COMPARE SIGNIFICANT GENES BETWEEN METHODS
# ============================================================================

def compare_significant_genes(scanpy_df, scvi_df, cell_type=None):
    """Compare overlap of significant genes between methods"""
    
    if cell_type:
        scanpy_genes = set(scanpy_df[scanpy_df['Cell_Type'] == cell_type]['gene'])
        scvi_genes = set(scvi_df[scvi_df['Cell_Type'] == cell_type]['gene'])
        title_suffix = f" - {cell_type}"
    else:
        scanpy_genes = set(scanpy_df['gene'])
        scvi_genes = set(scvi_df['gene'])
        title_suffix = " - All Cell Types"
    
    # Calculate overlap
    overlap = scanpy_genes & scvi_genes
    scanpy_only = scanpy_genes - scvi_genes
    scvi_only = scvi_genes - scanpy_genes
    
    print(f"\n{'='*60}")
    print(f"Significant Gene Overlap{title_suffix}")
    print(f"{'='*60}")
    print(f"Scanpy significant: {len(scanpy_genes)}")
    print(f"scVI significant: {len(scvi_genes)}")
    
    # Calculate overlap percentage
    total_unique = len(scanpy_genes | scvi_genes)
    if total_unique > 0:
        overlap_pct = len(overlap) / total_unique * 100
        print(f"Overlap: {len(overlap)} ({overlap_pct:.1f}% of total unique genes)")
    else:
        print(f"Overlap: {len(overlap)} (N/A - no genes)")
    
    print(f"Scanpy only: {len(scanpy_only)}")
    print(f"scVI only: {len(scvi_only)}")
    
    return {
        'scanpy_sig': scanpy_genes,
        'scvi_sig': scvi_genes,
        'overlap': overlap,
        'scanpy_only': scanpy_only,
        'scvi_only': scvi_only
    }

# Overall comparison
overall_comparison = compare_significant_genes(scanpy_comparison, scvi_comparison)

# Per cell type comparison
celltype_comparisons = {}
for ct in scanpy_comparison['Cell_Type'].unique():
    celltype_comparisons[ct] = compare_significant_genes(scanpy_comparison, scvi_comparison, ct)

# ============================================================================
# 4. MERGE RESULTS FOR CORRELATION ANALYSIS (OVERLAPPING GENES ONLY)
# ============================================================================

print("\n" + "="*60)
print("Merging overlapping genes for correlation analysis...")
print("="*60)

merged_results = pd.merge(
    scanpy_comparison[['gene', 'Cell_Type', 'log2FC', 'padj']],
    scvi_comparison[['gene', 'Cell_Type', 'log2FC', 'proba_de']],
    on=['gene', 'Cell_Type'],
    suffixes=('_scanpy', '_scvi'),
    how='inner'
)

print(f"Overlapping gene-celltype pairs: {len(merged_results)}")

# ============================================================================
# 5. CORRELATION ANALYSIS OF LOG2FC VALUES
# ============================================================================

print("\n" + "="*60)
print("Log2FC Correlation Analysis (Overlapping Genes)")
print("="*60)

# Remove infinite or NaN values
merged_clean = merged_results[
    np.isfinite(merged_results['log2FC_scanpy']) & 
    np.isfinite(merged_results['log2FC_scvi'])
].copy()

if len(merged_clean) > 0:
    # Calculate correlations
    pearson_r, pearson_p = pearsonr(merged_clean['log2FC_scanpy'], 
                                      merged_clean['log2FC_scvi'])
    spearman_r, spearman_p = spearmanr(merged_clean['log2FC_scanpy'], 
                                         merged_clean['log2FC_scvi'])
    
    print(f"Pearson correlation: {pearson_r:.3f} (p={pearson_p:.2e})")
    print(f"Spearman correlation: {spearman_r:.3f} (p={spearman_p:.2e})")
else:
    pearson_r = spearman_r = 0
    print("No overlapping genes for correlation analysis")

# ============================================================================
# 6. VISUALIZATION
# ============================================================================

# Create figure with multiple subplots
fig = plt.figure(figsize=(18, 12))

# --- Plot 1: Venn Diagram of Overall Overlap ---
ax1 = plt.subplot(2, 3, 1)
if len(overall_comparison['scanpy_sig']) > 0 or len(overall_comparison['scvi_sig']) > 0:
    venn2([overall_comparison['scanpy_sig'], overall_comparison['scvi_sig']], 
          set_labels=('Scanpy', 'scVI'),
          ax=ax1)
    ax1.set_title('Overlap of Significant DEGs\n(All Cell Types - Left vs Right)', fontsize=12, fontweight='bold')
else:
    ax1.text(0.5, 0.5, 'No significant genes', ha='center', va='center', fontsize=12)
    ax1.set_title('Overlap of Significant DEGs\n(All Cell Types - Left vs Right)', fontsize=12, fontweight='bold')
    ax1.axis('off')

# --- Plot 2: Log2FC Correlation Scatter Plot ---
ax2 = plt.subplot(2, 3, 2)
if len(merged_clean) > 0:
    ax2.scatter(merged_clean['log2FC_scanpy'], 
               merged_clean['log2FC_scvi'],
               alpha=0.3, s=10)
    ax2.axhline(0, color='gray', linestyle='--', linewidth=0.5)
    ax2.axvline(0, color='gray', linestyle='--', linewidth=0.5)
    
    # Set axis limits based on data range
    max_val = max(abs(merged_clean['log2FC_scanpy'].max()), 
                  abs(merged_clean['log2FC_scanpy'].min()),
                  abs(merged_clean['log2FC_scvi'].max()),
                  abs(merged_clean['log2FC_scvi'].min()))
    ax2.plot([-max_val, max_val], [-max_val, max_val], 'r--', linewidth=1, label='y=x')
    
    ax2.set_xlabel('Log2FC (Scanpy)', fontsize=11)
    ax2.set_ylabel('Log2FC (scVI)', fontsize=11)
    ax2.set_title(f'Log2FC Correlation\nPearson r={pearson_r:.3f}, Spearman ρ={spearman_r:.3f}', 
                 fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
else:
    ax2.text(0.5, 0.5, 'No overlapping genes\nfor correlation', ha='center', va='center', fontsize=12)
    ax2.set_title('Log2FC Correlation', fontsize=12, fontweight='bold')
    ax2.axis('off')

# --- Plot 3: Number of DEGs per Cell Type ---
ax3 = plt.subplot(2, 3, 3)
deg_counts = []
for ct in scanpy_comparison['Cell_Type'].unique():
    if ct in celltype_comparisons:
        scanpy_n = len(celltype_comparisons[ct]['scanpy_sig'])
        scvi_n = len(celltype_comparisons[ct]['scvi_sig'])
        overlap_n = len(celltype_comparisons[ct]['overlap'])
        deg_counts.append({
            'Cell_Type': ct,
            'Scanpy': scanpy_n,
            'scVI': scvi_n,
            'Overlap': overlap_n
        })

deg_counts_df = pd.DataFrame(deg_counts).sort_values('Overlap', ascending=False)

x = np.arange(len(deg_counts_df))
width = 0.35

ax3.barh(x - width/2, deg_counts_df['Scanpy'], width, label='Scanpy', alpha=0.8)
ax3.barh(x + width/2, deg_counts_df['scVI'], width, label='scVI', alpha=0.8)
ax3.set_yticks(x)
ax3.set_yticklabels(deg_counts_df['Cell_Type'], fontsize=9)
ax3.set_xlabel('Number of Significant DEGs', fontsize=11)
ax3.set_title('DEGs per Cell Type by Method', fontsize=12, fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3, axis='x')

# --- Plot 4: Agreement Direction (Up/Down regulated) ---
ax4 = plt.subplot(2, 3, 4)

if len(merged_results) > 0:
    # Classify direction agreement for overlapping genes
    merged_results['scanpy_dir'] = np.sign(merged_results['log2FC_scanpy'])
    merged_results['scvi_dir'] = np.sign(merged_results['log2FC_scvi'])
    merged_results['agreement'] = merged_results['scanpy_dir'] == merged_results['scvi_dir']
    
    agreement_counts = merged_results['agreement'].value_counts()
    
    # Ensure we have labels that match the data
    labels = []
    values = []
    colors = []
    
    if True in agreement_counts.index:
        labels.append('Agree')
        values.append(agreement_counts[True])
        colors.append('lightgreen')
    
    if False in agreement_counts.index:
        labels.append('Disagree')
        values.append(agreement_counts[False])
        colors.append('salmon')
    
    if len(values) > 0:
        ax4.pie(values, labels=labels, autopct='%1.1f%%', colors=colors, startangle=90)
        ax4.set_title(f'Direction Agreement\n(n={len(merged_results)} overlapping genes)', 
                     fontsize=12, fontweight='bold')
    else:
        ax4.text(0.5, 0.5, 'No data', ha='center', va='center', fontsize=12)
        ax4.axis('off')
else:
    ax4.text(0.5, 0.5, 'No overlapping genes', ha='center', va='center', fontsize=12)
    ax4.set_title('Direction Agreement\n(n=0 genes)', fontsize=12, fontweight='bold')
    ax4.axis('off')

# --- Plot 5: Effect Size Comparison (overlapping genes) ---
ax5 = plt.subplot(2, 3, 5)

if len(merged_clean) > 0:
    colors_map = merged_results.loc[merged_clean.index, 'agreement'].map({True: 'green', False: 'red'})
    ax5.scatter(np.abs(merged_clean['log2FC_scanpy']), 
               np.abs(merged_clean['log2FC_scvi']),
               alpha=0.4, s=20, c=colors_map)
    
    max_val = max(np.abs(merged_clean['log2FC_scanpy']).max(),
                  np.abs(merged_clean['log2FC_scvi']).max())
    ax5.plot([0, max_val], [0, max_val], 'k--', linewidth=1, alpha=0.5)
    
    ax5.set_xlabel('|Log2FC| (Scanpy)', fontsize=11)
    ax5.set_ylabel('|Log2FC| (scVI)', fontsize=11)
    ax5.set_title('Effect Size Comparison\n(Overlapping Genes)', 
                 fontsize=12, fontweight='bold')
    ax5.grid(True, alpha=0.3)
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='green', alpha=0.4, label='Same direction'),
                       Patch(facecolor='red', alpha=0.4, label='Opposite direction')]
    ax5.legend(handles=legend_elements, fontsize=9)
else:
    ax5.text(0.5, 0.5, 'No overlapping genes', ha='center', va='center', fontsize=12)
    ax5.set_title('Effect Size Comparison\n(n=0 genes)', fontsize=12, fontweight='bold')
    ax5.axis('off')

# --- Plot 6: Overlap percentage by cell type ---
ax6 = plt.subplot(2, 3, 6)

overlap_pcts = []
for ct in scanpy_comparison['Cell_Type'].unique():
    if ct in celltype_comparisons:
        comp = celltype_comparisons[ct]
        total = len(comp['scanpy_sig'] | comp['scvi_sig'])
        overlap_pct = len(comp['overlap']) / total * 100 if total > 0 else 0
        overlap_pcts.append({'Cell_Type': ct, 'Overlap_%': overlap_pct})

overlap_df = pd.DataFrame(overlap_pcts).sort_values('Overlap_%', ascending=True)

if len(overlap_df) > 0:
    ax6.barh(overlap_df['Cell_Type'], overlap_df['Overlap_%'], color='steelblue', alpha=0.7)
    ax6.set_xlabel('Overlap Percentage (%)', fontsize=11)
    ax6.set_title('Method Agreement by Cell Type', fontsize=12, fontweight='bold')
    ax6.grid(True, alpha=0.3, axis='x')
else:
    ax6.text(0.5, 0.5, 'No data', ha='center', va='center', fontsize=12)
    ax6.axis('off')

plt.tight_layout()
plt.savefig('left_vs_right_method_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✅ Comparison figure saved: left_vs_right_method_comparison.png")

# ============================================================================
# 7. SAVE COMPARISON SUMMARY
# ============================================================================

# Create summary table
summary_rows = []
for ct in scanpy_comparison['Cell_Type'].unique():
    if ct in celltype_comparisons:
        comp = celltype_comparisons[ct]
        
        # Calculate overlap percentage
        total = len(comp['scanpy_sig'] | comp['scvi_sig'])
        overlap_pct = (len(comp['overlap']) / total * 100) if total > 0 else 0
        
        summary_rows.append({
            'Cell_Type': ct,
            'Scanpy_DEGs': len(comp['scanpy_sig']),
            'scVI_DEGs': len(comp['scvi_sig']),
            'Overlap': len(comp['overlap']),
            'Scanpy_only': len(comp['scanpy_only']),
            'scVI_only': len(comp['scvi_only']),
            'Overlap_pct': overlap_pct
        })

summary_table = pd.DataFrame(summary_rows).sort_values('Overlap', ascending=False)
summary_table.to_csv('left_vs_right_method_comparison_summary.csv', index=False)

print("\n✅ Summary table saved: left_vs_right_method_comparison_summary.csv")

# Save detailed comparison of overlapping genes
if len(merged_results) > 0:
    merged_results.to_csv('left_vs_right_overlapping_genes_details.csv', index=False)
    print("✅ Detailed overlapping genes saved: left_vs_right_overlapping_genes_details.csv")

print("\n" + "="*60)
print("COMPARISON COMPLETE")
print("="*60)
print(f"\nSummary:")
print(f"  Total Scanpy DEGs: {len(scanpy_comparison)}")
print(f"  Total scVI DEGs: {len(scvi_comparison)}")
print(f"  Overlapping DEGs: {len(overall_comparison['overlap'])}")
print(f"  Scanpy only: {len(overall_comparison['scanpy_only'])}")
print(f"  scVI only: {len(overall_comparison['scvi_only'])}")

Loading results from both methods...

Scanpy significant DEGs: 5518
scVI significant DEGs: 45

Scanpy cell types: ['CD4+ T Cells' 'CMS3' 'Tip-like ECs' 'CD8+ T cells' 'B Cells' 'Spp1+'
 'Mast cells' 'Stromal 2' 'CMS2' 'Regulatory T Cells' 'Pericytes'
 'Dendritic cells' 'Gamma delta T cells' 'Helper 17 T cells'
 'Mature Enterocytes type 2' 'NK cells' 'Plasma Cells'
 'Plasmacytoid Dendritic Cells' 'Follicular helper T cells']
scVI cell types: ['CD4+ T Cells' 'CMS3' 'Tip-like ECs' 'Spp1+' 'Mast cells' 'Stromal 2'
 'CMS2' 'Pericytes' 'Dendritic cells' 'Gamma delta T cells'
 'Mature Enterocytes type 2' 'Plasmacytoid Dendritic Cells'
 'Follicular helper T cells']

Scanpy genes: 5518
scVI genes: 45

Significant Gene Overlap - All Cell Types
Scanpy significant: 3062
scVI significant: 26
Overlap: 16 (0.5% of total unique genes)
Scanpy only: 3046
scVI only: 10

Significant Gene Overlap - CD4+ T Cells
Scanpy significant: 569
scVI significant: 2
Overlap: 2 (0.4% of total unique genes)
Scanpy only:

KeyError: "['proba_de'] not in index"