In [1]:
import scanpy as sc
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scvi
import torch
from scipy import stats
from adjustText import adjust_text

In [2]:
import warnings
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

In [3]:
adata = sc.read('annotated.h5ad')
adata.obs

Unnamed: 0,samples,condition,location,n_genes,n_genes_by_counts,log1p_n_genes_by_counts,total_counts,log1p_total_counts,pct_counts_in_top_20_genes,pct_counts_mt,pct_counts_ribo,pct_counts_hb,low_label,low_score,_scvi_batch,_scvi_labels,overcluster,low_major,Cell_Type
711_AAACCCAAGTCGGGAT-1,711,Tumor,Right,707,707,6.562444,1396.0,7.242083,26.862464,10.100286,25.716331,0.000000,T follicular helper cells,0.103311,0,0,14,CD4+ T cells,CD4+ T Cells
711_AAACCCACAGAGGAAA-1,711,Tumor,Right,838,838,6.732211,1504.0,7.316548,23.803191,9.441490,13.962767,0.000000,Unknown,0.236022,0,0,11,CD4+ T cells,CD4+ T Cells
711_AAACCCACATGATAGA-1,711,Tumor,Right,435,435,6.077642,613.0,6.419995,17.781403,0.815661,18.270800,0.000000,Unknown,0.400460,0,0,36,CMS3,CMS3
711_AAACCCAGTCTCGCGA-1,711,Tumor,Right,579,579,6.363028,860.0,6.758094,23.255814,15.465117,6.395349,0.000000,Tip-like ECs,0.966648,0,0,28,Tip-like ECs,Tip-like ECs
711_AAACGAAGTTATCTTC-1,711,Tumor,Right,1384,1384,7.233455,2629.0,7.874739,19.246862,8.178015,15.405098,0.000000,gamma delta T cells,0.993837,0,0,15,CD8+ T cells,CD8+ T cells
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
706_TTTGGTTCAAACACCT-1,706,Tumor,Left,1815,1815,7.504392,4944.0,8.506132,21.318770,1.961974,21.318771,0.020227,CD8+ T cells,0.999997,0,0,1,CD8+ T cells,CD8+ T cells
706_TTTGGTTCAACGGGTA-1,706,Tumor,Left,909,909,6.813445,3528.0,8.168770,34.722222,0.368481,25.368483,0.000000,Pro-inflammatory,0.549964,0,0,12,SPP1+,Spp1+
706_TTTGGTTTCTATCGCC-1,706,Tumor,Left,4658,4658,8.446556,23600.0,10.069044,17.559322,3.237288,18.135593,0.004237,Stromal 2,0.961222,0,0,25,Stromal 2,Stromal 2
706_TTTGTTGCATCAGCGC-1,706,Tumor,Left,374,374,5.926926,686.0,6.532334,23.032070,10.932944,33.236153,0.000000,CMS2,0.978673,0,0,10,CMS2,CMS2


In [5]:
model = scvi.model.SCVI.load('the_model/', adata)
model

[34mINFO    [0m File the_model/model.pt already downloaded                                                                




In [6]:
# Filter to only tumor samples (assuming you want to compare left vs right tumors)
adata_tumor = adata[adata.obs['condition'] == 'Tumor'].copy()

print(f"\nTotal tumor samples: {adata_tumor.n_obs}")
print(f"Cell types: {adata_tumor.obs['Cell_Type'].nunique()}")

# Check tumor location distribution
if 'tumor_location' in adata_tumor.obs.columns:
    location_col = 'tumor_location'
elif 'Tumor_Location' in adata_tumor.obs.columns:
    location_col = 'Tumor_Location'
elif 'location' in adata_tumor.obs.columns:
    location_col = 'location'
else:
    print("\n⚠️  ERROR: Could not find tumor location column!")
    print("Available columns:", adata_tumor.obs.columns.tolist())
    # Uncomment the line below with the correct column name
    # location_col = 'YOUR_LOCATION_COLUMN_NAME'

print(f"\nTumor location distribution:")
print(adata_tumor.obs[location_col].value_counts())

# Run DEG analysis per cell type
de_results_location = {}
cell_types = adata_tumor.obs['Cell_Type'].unique()

print(f"\nAnalyzing {len(cell_types)} cell types...\n")

for cell_type in cell_types:
    print(f"{'='*60}")
    print(f"Analyzing: {cell_type}")
    print(f"{'='*60}")
    
    # Subset to this cell type
    adata_subset = adata_tumor[adata_tumor.obs['Cell_Type'] == cell_type].copy()
    
    # Check if we have both left and right samples
    location_counts = adata_subset.obs[location_col].value_counts()
    print(f"Left samples: {location_counts.get('Left', 0)}")
    print(f"Right samples: {location_counts.get('Right', 0)}")
    
    if len(location_counts) < 2:
        print(f"⚠️  Skipping {cell_type}: only one location present\n")
        continue
    
    # Check for both 'Left' and 'Right' specifically
    if 'Left' not in location_counts or 'Right' not in location_counts:
        print(f"⚠️  Skipping {cell_type}: missing Left or Right samples\n")
        continue
    
    if location_counts.min() < 10:  # minimum cell threshold
        print(f"⚠️  Warning: Low cell count in one group for {cell_type}")
        print(f"   Proceeding anyway, but results may be unreliable\n")
    
    try:
        # Run differential expression: Left vs Right
        de_df = model.differential_expression(
            adata_subset,
            groupby=location_col,
            group1="Left",
            group2="Right",
            delta=0.5,  # minimum log fold change threshold
            fdr_target=0.05,
            mode='change',
            pseudocounts=1e-7,
            test_mode='two'
        )
        de_results_location[cell_type] = de_df
        
        # Print summary
        n_sig = (de_df['is_de_fdr_0.05']).sum()
        print(f"Significant DEGs (FDR < 0.05): {n_sig}")
        print(f"Total genes tested: {len(de_df)}\n")
        
    except Exception as e:
        print(f"⚠️  Error analyzing {cell_type}: {str(e)}\n")
        continue

print("\n" + "="*70)
print("FILTERING SIGNIFICANT DEGs")
print("="*70)

# Filter results
filtered_de_results_location = {}

for cell_type, de_df in de_results_location.items():
    print(f"\n{cell_type}:")
    
    # Try strict filtering first
    de_filtered = de_df[
        (de_df['is_de_fdr_0.05']) & 
        (abs(de_df['lfc_mean']) > 0.5)
    ]
    
    print(f"  Strict filtering (FDR < 0.05, |LFC| > 0.5): {len(de_filtered)} genes")

    
    # Sort by effect size (absolute log fold change)
    de_filtered = de_filtered.sort_values('lfc_mean', key=abs, ascending=False)
    
    filtered_de_results_location[cell_type] = de_filtered

print("\n" + "="*70)
print("SAVING RESULTS")
print("="*70)

# Save individual cell type results
for cell_type, de_df in filtered_de_results_location.items():
    filename = f"DEG_{cell_type.replace(' ', '_')}_left_vs_right.csv"
    de_df.to_csv(filename)
    print(f"Saved: {filename}")

# Create summary table
summary_data_location = []

for cell_type, de_df in filtered_de_results_location.items():
    up_in_left = (de_df['lfc_mean'] > 0).sum()
    down_in_left = (de_df['lfc_mean'] < 0).sum()
    
    summary_data_location.append({
        'Cell_Type': cell_type,
        'Total_DEGs': len(de_df),
        'Upregulated_in_Left': up_in_left,
        'Downregulated_in_Left': down_in_left,
        'Upregulated_in_Right': down_in_left,  # genes down in left = up in right
        'Downregulated_in_Right': up_in_left,  # genes up in left = down in right
        'Top_Up_in_Left_Gene': de_df[de_df['lfc_mean'] > 0].index[0] if (de_df['lfc_mean'] > 0).any() else 'NA',
        'Top_Down_in_Left_Gene': de_df[de_df['lfc_mean'] < 0].index[0] if (de_df['lfc_mean'] < 0).any() else 'NA'
    })

summary_df_location = pd.DataFrame(summary_data_location)
summary_df_location.to_csv('DEG_summary_left_vs_right_all_celltypes.csv', index=False)
print("\nSaved: DEG_summary_left_vs_right_all_celltypes.csv")

print("\n" + "="*70)
print("ANALYSIS SUMMARY")
print("="*70)
print(f"\nTotal cell types analyzed: {len(de_results_location)}")
print(f"Cell types with significant DEGs: {len([ct for ct, df in filtered_de_results_location.items() if len(df) > 0])}")
print(f"\nTotal unique DEGs across all cell types: {len(set().union(*[set(df.index) for df in filtered_de_results_location.values()]))}")

# Print summary table
print("\n" + "="*70)
print("DEG SUMMARY BY CELL TYPE")
print("="*70)
print(summary_df_location.to_string(index=False))

print("\n✓ Left vs Right DEG analysis complete!")
print("\nGenerated files:")
print("  - Individual CSV files per cell type: DEG_[CellType]_left_vs_right.csv")
print("  - Summary table: DEG_summary_left_vs_right_all_celltypes.csv")


Total tumor samples: 43640
Cell types: 21

Tumor location distribution:
location
Left     22848
Right    20792
Name: count, dtype: int64

Analyzing 21 cell types...

Analyzing: CD4+ T Cells
Left samples: 5789
Right samples: 6144
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 2
Total genes tested: 14469

Analyzing: CMS3
Left samples: 1313
Right samples: 1585
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 10
Total genes tested: 14469

Analyzing: Tip-like ECs
Left samples: 222
Right samples: 359
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 3
Total genes tested: 14469

Analyzing: CD8+ T cells
Left samples: 2649
Right samples: 2406
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 0
Total genes tested: 14469

Analyzing: B Cells
Left samples: 3365
Right samples: 3656
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 0
Total genes tested: 14469

Analyzing: Spp1+
Left samples: 791
Right samples: 821
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 3
Total genes tested: 14469

Analyzing: Mast cells
Left samples: 241
Right samples: 387
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 2
Total genes tested: 14469

Analyzing: Stromal 2
Left samples: 405
Right samples: 444
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 2
Total genes tested: 14469

Analyzing: CMS2
Left samples: 785
Right samples: 976
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 4
Total genes tested: 14469

Analyzing: Regulatory T Cells
Left samples: 2385
Right samples: 1151
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 0
Total genes tested: 14469

Analyzing: Pericytes
Left samples: 196
Right samples: 238
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 3
Total genes tested: 14469

Analyzing: Dendritic cells
Left samples: 1364
Right samples: 697
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 4
Total genes tested: 14469

Analyzing: Gamma delta T cells
Left samples: 348
Right samples: 415
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 2
Total genes tested: 14469

Analyzing: Helper 17 T cells
Left samples: 1809
Right samples: 458
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 0
Total genes tested: 14469

Analyzing: Mature Enterocytes type 2
Left samples: 47
Right samples: 158
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 4
Total genes tested: 14469

Analyzing: NK cells
Left samples: 652
Right samples: 357
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 0
Total genes tested: 14469

Analyzing: Plasma Cells
Left samples: 134
Right samples: 65
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 0
Total genes tested: 14469

Analyzing: Stromal 3
Left samples: 8
Right samples: 15
   Proceeding anyway, but results may be unreliable

[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 955
Total genes tested: 14469

Analyzing: Plasmacytoid Dendritic Cells
Left samples: 60
Right samples: 26
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 6
Total genes tested: 14469

Analyzing: Follicular helper T cells
Left samples: 276
Right samples: 423
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 2
Total genes tested: 14469

Analyzing: Enteric glia cells
Left samples: 9
Right samples: 11
   Proceeding anyway, but results may be unreliable

[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


DE...:   0%|          | 0/1 [00:00<?, ?it/s]

Significant DEGs (FDR < 0.05): 4
Total genes tested: 14469


FILTERING SIGNIFICANT DEGs

CD4+ T Cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 2 genes

CMS3:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 10 genes

Tip-like ECs:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 2 genes

CD8+ T cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 0 genes

B Cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 0 genes

Spp1+:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 3 genes

Mast cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 2 genes

Stromal 2:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 1 genes

CMS2:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 4 genes

Regulatory T Cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 0 genes

Pericytes:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 3 genes

Dendritic cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 4 genes

Gamma delta T cells:
  Strict filtering (FDR < 0.05, |LFC| > 0.5): 2 genes

Helper 17 T cells:
  Strict filte

In [7]:
# ============================================================================
# VISUALIZATIONS: LEFT vs RIGHT CRC DEG ANALYSIS
# ============================================================================


print("\n" + "="*70)
print("GENERATING VISUALIZATIONS")
print("="*70)

# 1. Summary barplot
print("\n1. Creating summary barplot...")
fig, ax = plt.subplots(figsize=(10, 6))
summary_df_location.plot(
    x='Cell_Type',
    y=['Upregulated_in_Left', 'Downregulated_in_Left'],
    kind='bar',
    ax=ax,
    color=['#2ca02c', '#ff7f0e']
)
plt.xlabel('Cell Type', fontsize=12)
plt.ylabel('Number of DEGs', fontsize=12)
plt.title('DEGs: Left vs Right CRC by Cell Type', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.legend(['Up in Left', 'Down in Left (Up in Right)'])
plt.tight_layout()
plt.savefig('DEG_summary_barplot_left_vs_right.png', dpi=300)
plt.close()
print("✓ Saved: DEG_summary_barplot_left_vs_right.png")

# 2. Volcano plots for each cell type
print("\n2. Creating volcano plots...")
for cell_type, de_df in filtered_de_results_location.items():
    if len(de_df) == 0:
        continue
    
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # All genes for this cell type
    all_genes = de_results_location[cell_type]
    
    # Create volcano plot data
    x = all_genes['lfc_mean']
    y = -np.log10(all_genes['proba_not_de'] + 1e-10)
    
    # Separate significant genes by direction
    sig_up = all_genes.index.isin(de_df.index) & (all_genes['lfc_mean'] > 0)
    sig_down = all_genes.index.isin(de_df.index) & (all_genes['lfc_mean'] < 0)
    non_sig = ~all_genes.index.isin(de_df.index)
    
    # Plot non-significant genes
    ax.scatter(x[non_sig], y[non_sig], c='lightgray', alpha=0.5, s=10, label='Not significant')
    
    # Plot significant genes
    ax.scatter(x[sig_up], y[sig_up], c='#2ca02c', alpha=0.6, s=15, label='Upregulated in Left')
    ax.scatter(x[sig_down], y[sig_down], c='#ff7f0e', alpha=0.6, s=15, label='Upregulated in Right')
    
    # Add labels for top genes
    top_genes = de_df.head(10)
    for gene in top_genes.index:
        ax.annotate(
            gene,
            xy=(all_genes.loc[gene, 'lfc_mean'], -np.log10(all_genes.loc[gene, 'proba_not_de'] + 1e-10)),
            fontsize=8,
            alpha=0.8
        )
    
    # Add reference lines
    ax.axhline(y=-np.log10(0.05), color='black', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
    ax.axvline(x=0.5, color='black', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.axvline(x=-0.5, color='black', linestyle='--', linewidth=0.5, alpha=0.5)
    
    ax.set_xlabel('Log Fold Change (Left vs Right)', fontsize=12)
    ax.set_ylabel('-log10(proba_not_de)', fontsize=12)
    ax.set_title(f'Volcano Plot: {cell_type}\n(Left vs Right CRC)', fontsize=14, fontweight='bold')
    ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)
    
    plt.tight_layout()
    plt.savefig(f'volcano_left_right_{cell_type.replace(" ", "_")}.png', dpi=300, bbox_inches='tight')
    plt.close()

print(f"✓ Created {len(filtered_de_results_location)} volcano plots")

# ============================================================================
# HEATMAP VISUALIZATIONS
# ============================================================================
print("\n3. Generating heatmaps...")

# --- Heatmap 1: Top DEGs per cell type (individual heatmap per cell type) ---
print("\n3a. Individual cell type heatmaps...")
for cell_type, de_df in filtered_de_results_location.items():
    if len(de_df) == 0:
        continue
    
    # Get top genes by absolute log fold change
    de_df['abs_lfc'] = abs(de_df['lfc_mean'])
    n_genes = min(50, len(de_df))
    top_genes = de_df.nlargest(n_genes, 'abs_lfc')
    
    # Subset adata for this cell type and these genes
    adata_subset = adata_tumor[adata_tumor.obs['Cell_Type'] == cell_type, top_genes.index].copy()
    
    if adata_subset.n_obs == 0:
        continue
    
    # Calculate mean expression per location
    left_cells = adata_subset[adata_subset.obs[location_col] == 'Left']
    right_cells = adata_subset[adata_subset.obs[location_col] == 'Right']
    
    # Get mean expression
    if 'X_scvi' in adata.obsm:
        # Use scVI denoised expression
        left_mean = pd.DataFrame(
            model.get_normalized_expression(left_cells, return_mean=True),
            index=top_genes.index,
            columns=['Left']
        )
        right_mean = pd.DataFrame(
            model.get_normalized_expression(right_cells, return_mean=True),
            index=top_genes.index,
            columns=['Right']
        )
    else:
        # Fallback to regular mean
        if hasattr(left_cells.X, 'toarray'):
            left_data = left_cells.X.toarray().mean(axis=0)
            right_data = right_cells.X.toarray().mean(axis=0)
        else:
            left_data = left_cells.X.mean(axis=0)
            right_data = right_cells.X.mean(axis=0)
        
        # Ensure correct shape
        if len(left_data.shape) > 1:
            left_data = left_data.flatten()
            right_data = right_data.flatten()
        
        left_mean = pd.DataFrame(left_data, index=top_genes.index, columns=['Left'])
        right_mean = pd.DataFrame(right_data, index=top_genes.index, columns=['Right'])
    
    # Combine
    heatmap_data = pd.concat([left_mean, right_mean], axis=1)
    
    # Z-score normalize
    heatmap_data_zscore = heatmap_data.T
    heatmap_data_zscore = (heatmap_data_zscore - heatmap_data_zscore.mean(axis=0)) / (heatmap_data_zscore.std(axis=0) + 1e-10)
    heatmap_data_zscore = heatmap_data_zscore.T
    
    # Create heatmap
    fig, ax = plt.subplots(figsize=(6, max(8, n_genes * 0.3)))
    sns.heatmap(
        heatmap_data_zscore,
        cmap='RdBu_r',
        center=0,
        cbar_kws={'label': 'Z-score'},
        yticklabels=True,
        xticklabels=True,
        linewidths=0.5,
        ax=ax
    )
    
    ax.set_title(f'Top {n_genes} DEGs: {cell_type}\n(Left vs Right CRC)', fontsize=12, fontweight='bold')
    ax.set_xlabel('Tumor Location', fontsize=10)
    ax.set_ylabel('Genes', fontsize=10)
    
    plt.tight_layout()
    plt.savefig(f'heatmap_top_genes_left_right_{cell_type.replace(" ", "_")}.png', dpi=300, bbox_inches='tight')
    plt.close()

print("✓ Individual cell type heatmaps created")

# --- Heatmap 2: Combined heatmap showing LFC across all cell types ---
print("\n3b. Combined LFC heatmap...")
lfc_matrix_location = pd.DataFrame()
for cell_type, de_df in filtered_de_results_location.items():
    if len(de_df) == 0:
        continue
    lfc_matrix_location[cell_type] = de_df['lfc_mean']

# Fill NaN with 0
lfc_matrix_location = lfc_matrix_location.fillna(0)

# Get top genes by variance
if len(lfc_matrix_location) > 100:
    gene_variance = lfc_matrix_location.var(axis=1)
    top_genes_idx = gene_variance.nlargest(100).index
    lfc_matrix_top = lfc_matrix_location.loc[top_genes_idx]
else:
    lfc_matrix_top = lfc_matrix_location

# Create combined heatmap
if len(lfc_matrix_top) > 0 and len(lfc_matrix_top.columns) > 0:
    fig, ax = plt.subplots(figsize=(max(10, len(lfc_matrix_top.columns) * 0.8), 
                                     max(12, len(lfc_matrix_top) * 0.2)))
    
    sns.heatmap(
        lfc_matrix_top,
        cmap='RdBu_r',
        center=0,
        cbar_kws={'label': 'Log Fold Change (Left vs Right)'},
        yticklabels=True,
        xticklabels=True,
        linewidths=0.5,
        vmin=-3,
        vmax=3,
        ax=ax
    )
    
    ax.set_title('Log Fold Changes Across Cell Types: Left vs Right CRC\n(Green = Upregulated in Left, Orange = Upregulated in Right)', 
                 fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('Cell Type', fontsize=12)
    ax.set_ylabel('Genes', fontsize=12)
    
    plt.tight_layout()
    plt.savefig('heatmap_lfc_all_celltypes_left_right.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    print("✓ Combined LFC heatmap created")

# --- Heatmap 3: Combined LFC and significance ---
print("\n3c. Combined LFC and significance heatmap...")
lfc_matrix_all_location = pd.DataFrame()
sig_matrix_all_location = pd.DataFrame()

for cell_type, de_df in de_results_location.items():
    lfc_matrix_all_location[cell_type] = de_df['lfc_mean']
    sig_matrix_all_location[cell_type] = -np.log10(de_df['proba_not_de'] + 1e-10)

# Get genes that are significant in at least one cell type
sig_genes_location = set()
for cell_type, de_df in filtered_de_results_location.items():
    sig_genes_location.update(de_df.index)

if len(sig_genes_location) > 0:
    lfc_matrix_sig = lfc_matrix_all_location.loc[list(sig_genes_location)]
    sig_matrix_sig = sig_matrix_all_location.loc[list(sig_genes_location)]
    
    # Limit to top 100 if too many
    if len(lfc_matrix_sig) > 100:
        gene_variance = lfc_matrix_sig.var(axis=1)
        top_genes_idx = gene_variance.nlargest(100).index
        lfc_matrix_sig = lfc_matrix_sig.loc[top_genes_idx]
        sig_matrix_sig = sig_matrix_sig.loc[top_genes_idx]
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(max(16, len(lfc_matrix_sig.columns) * 1.5), 
                                                    max(10, len(lfc_matrix_sig) * 0.2)))
    
    # Heatmap 1: Log Fold Change
    sns.heatmap(
        lfc_matrix_sig,
        cmap='RdBu_r',
        center=0,
        cbar_kws={'label': 'Log Fold Change'},
        yticklabels=True,
        xticklabels=True,
        linewidths=0.5,
        vmin=-3,
        vmax=3,
        ax=ax1
    )
    ax1.set_title('Log Fold Change', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Cell Type', fontsize=10)
    ax1.set_ylabel('Genes', fontsize=10)
    
    # Heatmap 2: Significance
    sns.heatmap(
        sig_matrix_sig,
        cmap='YlOrRd',
        cbar_kws={'label': '-log10(proba_not_de)'},
        yticklabels=True,
        xticklabels=True,
        linewidths=0.5,
        ax=ax2
    )
    ax2.set_title('Significance', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Cell Type', fontsize=10)
    ax2.set_ylabel('', fontsize=10)
    
    fig.suptitle('DEG Analysis Across Cell Types: Left vs Right CRC', 
                 fontsize=14, fontweight='bold', y=1.00)
    
    plt.tight_layout()
    plt.savefig('heatmap_lfc_and_significance_left_right.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    print("✓ Combined LFC and significance heatmap created")

# --- Heatmap 4: Cell type comparison (top DEGs per cell type across all cell types) ---
print("\n3d. Cell type comparison heatmap...")

celltype_top_genes_location = {}
genes_per_celltype = 5

for cell_type, de_df in filtered_de_results_location.items():
    if len(de_df) == 0:
        continue
    
    de_df['abs_lfc'] = abs(de_df['lfc_mean'])
    n_genes = min(genes_per_celltype, len(de_df))
    top_genes = de_df.nlargest(n_genes, 'abs_lfc')
    celltype_top_genes_location[cell_type] = top_genes.index.tolist()

# Combine all unique genes
all_top_genes_location = []
for genes in celltype_top_genes_location.values():
    all_top_genes_location.extend(genes)
all_top_genes_location = list(set(all_top_genes_location))

print(f"Total unique genes selected: {len(all_top_genes_location)}")

# Create comparison matrix
lfc_comparison_matrix_location = pd.DataFrame(index=all_top_genes_location)

for cell_type in de_results_location.keys():
    if cell_type in celltype_top_genes_location:
        cell_de = de_results_location[cell_type]
        lfc_comparison_matrix_location[cell_type] = cell_de.loc[all_top_genes_location, 'lfc_mean']

lfc_comparison_matrix_location = lfc_comparison_matrix_location.fillna(0)

# Sort genes by cell type
gene_to_celltype = {}
for cell_type, genes in celltype_top_genes_location.items():
    for gene in genes:
        if gene not in gene_to_celltype:
            gene_to_celltype[gene] = cell_type

sorted_genes = []
for cell_type in lfc_comparison_matrix_location.columns:
    if cell_type in celltype_top_genes_location:
        genes_for_this_type = [g for g in celltype_top_genes_location[cell_type] 
                               if g in lfc_comparison_matrix_location.index]
        sorted_genes.extend(genes_for_this_type)

remaining_genes = [g for g in lfc_comparison_matrix_location.index if g not in sorted_genes]
sorted_genes.extend(remaining_genes)

lfc_comparison_matrix_location = lfc_comparison_matrix_location.loc[sorted_genes]

# Create heatmap
fig, ax = plt.subplots(figsize=(max(12, len(lfc_comparison_matrix_location.columns) * 1.2), 
                                 max(10, len(lfc_comparison_matrix_location) * 0.25)))

sns.heatmap(
    lfc_comparison_matrix_location,
    cmap='RdBu_r',
    center=0,
    cbar_kws={'label': 'Log Fold Change (Left vs Right)', 'shrink': 0.8},
    yticklabels=True,
    xticklabels=True,
    linewidths=0.3,
    linecolor='gray',
    vmin=-3,
    vmax=3,
    ax=ax
)

ax.set_title(f'Top {genes_per_celltype} DEGs per Cell Type: Expression Across All Cell Types\n' +
             '(Left vs Right CRC)', 
             fontsize=14, fontweight='bold', pad=20)
ax.set_xlabel('Cell Type', fontsize=12, fontweight='bold')
ax.set_ylabel('Genes (grouped by cell type)', fontsize=12, fontweight='bold')

plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('heatmap_celltype_deg_comparison_left_right.png', dpi=300, bbox_inches='tight')
plt.close()

print("✓ Cell type comparison heatmap created")

print("\n" + "="*70)
print("✓ ALL VISUALIZATIONS COMPLETE!")
print("="*70)
print("\nGenerated files:")
print("  - DEG_summary_barplot_left_vs_right.png")
print("  - Individual volcano plots per cell type: volcano_left_right_[CellType].png")
print("  - Individual heatmaps per cell type: heatmap_top_genes_left_right_[CellType].png")
print("  - Combined LFC heatmap: heatmap_lfc_all_celltypes_left_right.png")
print("  - Combined LFC + significance heatmap: heatmap_lfc_and_significance_left_right.png")
print("  - Cell type comparison heatmap: heatmap_celltype_deg_comparison_left_right.png")


GENERATING VISUALIZATIONS

1. Creating summary barplot...
✓ Saved: DEG_summary_barplot_left_vs_right.png

2. Creating volcano plots...
✓ Created 21 volcano plots

3. Generating heatmaps...

3a. Individual cell type heatmaps...
✓ Individual cell type heatmaps created

3b. Combined LFC heatmap...
✓ Combined LFC heatmap created

3c. Combined LFC and significance heatmap...
✓ Combined LFC and significance heatmap created

3d. Cell type comparison heatmap...
Total unique genes selected: 22
✓ Cell type comparison heatmap created

✓ ALL VISUALIZATIONS COMPLETE!

Generated files:
  - DEG_summary_barplot_left_vs_right.png
  - Individual volcano plots per cell type: volcano_left_right_[CellType].png
  - Individual heatmaps per cell type: heatmap_top_genes_left_right_[CellType].png
  - Combined LFC heatmap: heatmap_lfc_all_celltypes_left_right.png
  - Combined LFC + significance heatmap: heatmap_lfc_and_significance_left_right.png
  - Cell type comparison heatmap: heatmap_celltype_deg_comparison