In [1]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Input file and output locations used by the rest of the analysis
combined_data_file = "a549_combined_data.h5ad"
results_dir = "a549_perturbation_analysis"
fig_dir = os.path.join(results_dir, "figures")

# Make sure both output directories exist (parent first, then figures)
for directory in (results_dir, fig_dir):
    os.makedirs(directory, exist_ok=True)

# Read the combined AnnData object that an earlier step saved to disk
print(f"Loading combined dataset from {combined_data_file}...")
combined = sc.read_h5ad(combined_data_file)
print(f"Combined data shape: {combined.shape}")

# Function to sanitize filenames
def sanitize_filename(filename):
    """Return *filename* with every invalid filename character turned into '_'.

    The characters treated as invalid are ``< > : " / \\ | ? *``; all other
    characters are passed through unchanged.
    """
    invalid_chars = r'<>:"/\|?*'
    # One translation table maps each invalid character to an underscore,
    # so the whole string is rewritten in a single pass.
    replacement_table = str.maketrans(invalid_chars, '_' * len(invalid_chars))
    return filename.translate(replacement_table)

# Report how the cells split between the control and treated groups
is_control = combined.obs['is_control']
print(f"Control samples: {is_control.sum()}")
print(f"Treatment samples: {(~is_control).sum()}")

# Every distinct label in the 'drug' column except the 'control' label
unique_drugs = [
    drug for drug in combined.obs['drug'].unique() if drug != 'control'
]

print(f"Analyzing differential expression for {len(unique_drugs)} drugs...")

# Thresholds used by the per-drug differential expression loop below
MIN_CELLS_PER_DRUG = 10    # drugs with fewer treated cells are skipped
PADJ_THRESHOLD = 0.05      # adjusted p-value cutoff for "significant"
TOP_GENES_PER_DRUG = 50    # maximum rows written to each per-drug CSV
P_VALUE_FLOOR = 1e-300     # stands in for p == 0 so -log10(p) stays finite


def _plot_volcano(de_genes, drug, out_path):
    """Draw a volcano plot for one drug-vs-control comparison and save it.

    Args:
        de_genes: DataFrame with at least the columns 'names',
            'logfoldchanges', 'pvals' and 'pvals_adj'.
        drug: Drug label, used only in the plot title.
        out_path: Destination path of the PNG file.

    The figure is always closed, even if plotting or saving raises, so a
    failure for one drug does not leave an open figure behind.
    """
    # Compute the y-axis once; exact zeros are replaced to avoid log(0)
    neg_log_p = -np.log10(de_genes['pvals'].replace(0, P_VALUE_FLOOR))
    significant = de_genes['pvals_adj'] < PADJ_THRESHOLD

    fig = plt.figure(figsize=(10, 8))
    try:
        # All genes, then the significant ones drawn on top in red
        plt.scatter(de_genes['logfoldchanges'], neg_log_p, alpha=0.5)
        plt.scatter(
            de_genes.loc[significant, 'logfoldchanges'],
            neg_log_p[significant],
            color='red', alpha=0.8
        )

        # Label the ten genes with the smallest raw p-values
        top_genes = de_genes.nsmallest(10, 'pvals')
        for name, lfc, pval in zip(
            top_genes['names'], top_genes['logfoldchanges'], top_genes['pvals']
        ):
            if pval == 0:
                pval = P_VALUE_FLOOR  # Avoid log(0)
            plt.annotate(
                name,
                (lfc, -np.log10(pval)),
                xytext=(5, 5),
                textcoords='offset points'
            )

        # Reference lines: p = 0.05 and |log fold change| = 1
        plt.axhline(-np.log10(0.05), linestyle='--', color='gray')
        plt.axvline(-1, linestyle='--', color='gray')
        plt.axvline(1, linestyle='--', color='gray')

        plt.xlabel('Log Fold Change')
        plt.ylabel('-log10(p-value)')
        plt.title(f'Differential Expression: {drug} vs Control')

        plt.savefig(out_path)
    finally:
        plt.close(fig)


# Significant genes of each successfully analysed drug; concatenated once
# after the loop instead of growing a DataFrame on every iteration.
significant_frames = []

for drug in unique_drugs:
    print(f"Processing drug: {drug}")

    # Count treated cells from the boolean mask; no AnnData view is needed
    n_drug_cells = int((combined.obs['drug'] == drug).sum())

    if n_drug_cells < MIN_CELLS_PER_DRUG:
        print(f"  Skipping {drug}: too few cells ({n_drug_cells})")
        continue

    # Best-effort per drug: a failure is reported and the loop moves on
    try:
        # Wilcoxon rank-sum test of this drug's cells against the controls
        sc.tl.rank_genes_groups(combined, 'drug', groups=[drug], reference='control', method='wilcoxon')

        # Extract results for this drug
        de_genes = sc.get.rank_genes_groups_df(combined, group=drug)
        de_genes['drug'] = drug

        # Filter for significantly differentially expressed genes
        significant_genes = de_genes[de_genes['pvals_adj'] < PADJ_THRESHOLD]
        significant_frames.append(significant_genes)

        # Use sanitized filename for every per-drug output
        safe_name = sanitize_filename(drug)

        _plot_volcano(
            de_genes, drug,
            os.path.join(fig_dir, f'volcano_plot_{safe_name}.png')
        )

        # Save the first rows of the significant table for this drug.
        # NOTE(review): rows keep the order returned by
        # rank_genes_groups_df (presumably ranked by test score) - confirm
        # this is the intended meaning of "top genes".
        significant_genes.head(TOP_GENES_PER_DRUG).to_csv(
            os.path.join(results_dir, f'top_genes_{safe_name}.csv'), index=False)

        print(f"  Found {len(significant_genes)} significantly affected genes")

    except Exception as e:
        print(f"  Error analyzing {drug}: {e}")

# Results dataframe holding the significant genes of every analysed drug
results = pd.concat(significant_frames) if significant_frames else pd.DataFrame()

# Save combined results: cross-drug tables, summary figures and a text report
if not results.empty:
    # Number of significant genes per drug, largest first
    drug_gene_counts = results.groupby('drug').size().reset_index(name='sig_gene_count')
    drug_gene_counts = drug_gene_counts.sort_values('sig_gene_count', ascending=False)

    # Save summary
    drug_gene_counts.to_csv(os.path.join(results_dir, 'drug_affected_gene_counts.csv'), index=False)

    # Save full results
    results.to_csv(os.path.join(results_dir, 'all_drug_gene_effects.csv'), index=False)

    # Create summary figure of drugs by number of affected genes
    plt.figure(figsize=(12, 8))
    sns.barplot(x='drug', y='sig_gene_count', data=drug_gene_counts)
    plt.xticks(rotation=90)
    plt.title('Number of Significantly Affected Genes by Drug')
    plt.tight_layout()
    plt.savefig(os.path.join(fig_dir, 'drug_gene_count_summary.png'))
    plt.close()

    # Number of rows (one per drug in which the gene was significant)
    # per gene, largest first
    gene_drug_counts = results.groupby('names').size().reset_index(name='drug_count')
    gene_drug_counts = gene_drug_counts.sort_values('drug_count', ascending=False)

    # Save genes affected by multiple drugs
    gene_drug_counts.head(100).to_csv(os.path.join(results_dir, 'multi_drug_affected_genes.csv'), index=False)

    # Create heatmap of top genes across drugs
    top_genes = gene_drug_counts.head(20)['names'].tolist()
    top_drugs = drug_gene_counts.head(15)['drug'].tolist()

    # Filter results for top genes and drugs
    heatmap_data = results[
        (results['names'].isin(top_genes)) &
        (results['drug'].isin(top_drugs))
    ]

    if not heatmap_data.empty:
        # Gene x drug matrix of log fold changes; pairs absent from the
        # significant results are filled with 0
        pivot_data = heatmap_data.pivot_table(
            index='names',
            columns='drug',
            values='logfoldchanges',
            fill_value=0
        )

        # Create heatmap
        plt.figure(figsize=(15, 10))
        sns.heatmap(pivot_data, cmap='RdBu_r', center=0, annot=False)
        plt.title('Log Fold Changes of Top Genes Across Drugs')
        plt.tight_layout()
        plt.savefig(os.path.join(fig_dir, 'gene_drug_heatmap.png'))
        plt.close()

    print("\nTop 10 genes affected by multiple drugs:")
    print(gene_drug_counts.head(10))

    print("\nTop 10 drugs by number of affected genes:")
    print(drug_gene_counts.head(10))

    # Assemble the text report first, then write it in one call
    is_control = combined.obs['is_control']
    top_drug_rows = drug_gene_counts.head(10)
    top_gene_rows = gene_drug_counts.head(10)

    summary_lines = [
        "A549 Perturbation Analysis Summary\n",
        "==================================\n\n",
        f"Total cells analyzed: {combined.shape[0]}\n",
        f"Control cells: {is_control.sum()}\n",
        f"Treatment cells: {(~is_control).sum()}\n\n",
        f"Unique drugs analyzed: {len(unique_drugs)}\n\n",
        "Top 10 drugs by number of affected genes:\n",
    ]
    summary_lines.extend(
        f"- {drug_name}: {gene_count} genes\n"
        for drug_name, gene_count in zip(
            top_drug_rows['drug'], top_drug_rows['sig_gene_count']
        )
    )
    summary_lines.append("\nTop 10 genes affected by multiple drugs:\n")
    summary_lines.extend(
        f"- {gene_name}: affected by {drug_count} drugs\n"
        for gene_name, drug_count in zip(
            top_gene_rows['names'], top_gene_rows['drug_count']
        )
    )

    # Drug labels can contain non-ASCII characters, so the encoding is
    # fixed to UTF-8 rather than left to the platform default.
    with open(os.path.join(results_dir, 'analysis_summary.txt'), 'w', encoding='utf-8') as f:
        f.writelines(summary_lines)

else:
    print("No significant results found.")

print(f"Analysis complete. Results saved to {results_dir}")

Loading combined dataset from a549_combined_data.h5ad...
Combined data shape: (246262, 33388)
Control samples: 20617
Treatment samples: 225645
Analyzing differential expression for 21 drugs...
Processing drug: wild-type (wt) virus
  Found 13180 significantly affected genes
Processing drug: irradiated a549 cells (6 gy γ-
  Found 14516 significantly affected genes
Processing drug: infected (cal07, 16 hours, rep
  Found 14188 significantly affected genes
Processing drug: car t cell therapy with suv39h
  Found 16322 significantly affected genes
Processing drug: infected with h3n2 (a/perth/16
  Found 16950 significantly affected genes
Processing drug: 8 hours post infection
  Found 12955 significantly affected genes
Processing drug: irradiation
  Found 15578 significantly affected genes
Processing drug: ritonavir, gemcitabine, cispla
  Found 14251 significantly affected genes
Processing drug: glyconanomaterials for combati
  Found 14101 significantly affected genes
Processing drug: ns1 4xst