# Essential Genes Analysis

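This notebook simulates essential-gene data and plots each of the eight essential-gene columns in its own subplot. Each subplot compares genes associated with fewer than 10 unique diseases (`uniqueDiseases < 10`) against genes associated with 10 or more (`uniqueDiseases >= 10`).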


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from typing import List, Dict, Any

# Set style for better plots.
# Order matters: plt.style.use('default') resets Matplotlib's rcParams
# (including the colour cycle), so the seaborn palette is applied afterwards.
plt.style.use('default')
sns.set_palette("husl")

In [None]:
def create_sample_essential_genes_data(n_genes: int = 1000, seed: int = 42) -> pd.DataFrame:
    """Create simulated essential-gene data with 8 essential gene columns and disease counts.

    Random draws come from a private ``np.random.RandomState(seed)``. This
    yields exactly the values that seeding NumPy's global generator with the
    same seed would, but leaves the global random state untouched.

    Args:
        n_genes: Number of genes (rows) to simulate.
        seed: Seed for the private random generator. Defaults to 42 for
            reproducible output.

    Returns:
        DataFrame with columns:

        - ``geneId``: ``"ENSG"`` followed by the zero-padded 8-digit row index.
        - ``uniqueDiseases``: integer disease counts clipped to [1, 50].
        - Six beta-distributed score columns with values in [0, 1].
        - ``essential_rank_crispr`` / ``essential_rank_rnai``: integers in
          [1, n_genes]. These are drawn independently with replacement, so
          ties are possible (they are not permutations).
    """
    # Private legacy generator: same stream as np.random.seed(seed) would give,
    # without resetting global state as a side effect.
    rng = np.random.RandomState(seed)

    gene_ids = [f"ENSG{i:08d}" for i in range(n_genes)]

    # NOTE: the order of these draws determines the generated values for a
    # given seed; reordering entries changes every subsequent column.
    essential_gene_columns = {
        'essential_score_crispr': rng.beta(2, 5, n_genes),  # CRISPR essentiality scores
        'essential_score_rnai': rng.beta(2, 4, n_genes),    # RNAi essentiality scores
        'essential_score_combined': rng.beta(3, 3, n_genes),  # Combined essentiality
        'essential_rank_crispr': rng.randint(1, n_genes + 1, n_genes),  # CRISPR ranking
        'essential_rank_rnai': rng.randint(1, n_genes + 1, n_genes),    # RNAi ranking
        'essential_probability': rng.beta(1, 3, n_genes),   # Probability of being essential
        'cell_viability_score': rng.beta(2, 3, n_genes),    # Cell viability impact
        'fitness_score': rng.beta(2, 4, n_genes),           # Fitness score
    }

    # Heavy-tailed disease counts: most genes get few diseases, a few get many.
    # Rounded to integers and capped at 50.
    unique_diseases = rng.pareto(1.5, n_genes) * 2 + 1
    unique_diseases = np.clip(np.round(unique_diseases).astype(int), 1, 50)

    return pd.DataFrame({
        'geneId': gene_ids,
        'uniqueDiseases': unique_diseases,
        **essential_gene_columns,
    })

In [None]:
def plot_essential_genes_by_disease_groups(
    df: pd.DataFrame,
    save_path: str = None,
    disease_threshold: int = 10,
) -> None:
    """Plot every essential-gene column split into low/high disease-count groups.

    Columns are selected by name: anything starting with ``essential_`` or
    ending with ``_score``. Genes are split on ``uniqueDiseases`` into those
    below ``disease_threshold`` and those at or above it. Rows with a missing
    count fall into neither group.

    Columns whose name contains ``rank`` are drawn as box plots; all other
    columns are drawn as overlaid density histograms. When both groups have
    data, a panel also shows a two-sided Mann-Whitney U p-value. Panels are
    laid out two per row. A short grouping summary is printed afterwards.

    Args:
        df: DataFrame with a ``uniqueDiseases`` column and essential gene
            columns. It is not modified.
        save_path: Optional path to save the figure (300 dpi).
        disease_threshold: Disease count separating the two groups
            (default 10).

    Raises:
        ValueError: If no column matches the essential-gene naming rule.
    """
    essential_cols = [
        col for col in df.columns
        if col.startswith('essential_') or col.endswith('_score')
    ]
    if not essential_cols:
        raise ValueError(
            "No essential gene columns found: expected names starting with "
            "'essential_' or ending with '_score'."
        )

    # Use boolean masks instead of adding a helper column, so the caller's
    # DataFrame is not mutated. Using `>=` for the high group (rather than
    # negating the low mask) keeps rows with a missing count out of both
    # groups, consistent with the printed summary.
    counts = df['uniqueDiseases']
    group_masks = {
        f'<{disease_threshold} diseases': counts < disease_threshold,
        f'>={disease_threshold} diseases': counts >= disease_threshold,
    }

    # Two panels per row; 8 columns give a 4x2 grid at 15x20 inches.
    n_cols = 2
    n_rows = (len(essential_cols) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    fig.suptitle('Essential Genes Analysis by Disease Group', fontsize=16, y=0.98)
    axes_flat = axes.flatten()

    for ax, col in zip(axes_flat, essential_cols):
        group_values = {
            label: df.loc[mask, col].dropna() for label, mask in group_masks.items()
        }
        _plot_group_panel(ax, col, group_values)

    # An odd number of columns leaves an empty slot in the last row.
    for ax in axes_flat[len(essential_cols):]:
        ax.set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    _print_disease_summary(df, disease_threshold)


def _plot_group_panel(ax, col: str, group_values: dict) -> None:
    """Draw one subplot comparing the disease groups for column ``col``.

    ``group_values`` maps group label -> non-missing values, in
    (low group, high group) order. Groups with no values are skipped.
    """
    box_colors = ['lightblue', 'lightcoral']
    hist_colors = ['blue', 'red']

    # Keep each group's position k so it keeps its own colour even when
    # the other group is empty.
    present = [
        (k, f'{label}\n(n={len(values)})', values)
        for k, (label, values) in enumerate(group_values.items())
        if len(values) > 0
    ]

    if 'rank' in col.lower():
        # Rank columns: box plots.
        if present:
            bp = ax.boxplot([values for _, _, values in present], patch_artist=True)
            # Set tick labels explicitly. Matplotlib 3.9 renamed boxplot's
            # `labels` kwarg to `tick_labels`; this form works on all versions.
            ax.set_xticks(range(1, len(present) + 1))
            ax.set_xticklabels([label for _, label, _ in present])
            for patch, (k, _, _) in zip(bp['boxes'], present):
                patch.set_facecolor(box_colors[k])
        ax.set_ylabel('Rank')
        ax.set_xlabel('Group')
    else:
        # Score columns: overlaid density histograms.
        for k, label, values in present:
            ax.hist(values, bins=30, alpha=0.6, color=hist_colors[k],
                    label=label, density=True)
        ax.set_ylabel('Density')
        ax.set_xlabel('Value')
        if present:  # avoid the "no artists with labels" warning
            ax.legend()

    ax.set_title(col.replace('_', ' ').title(), fontsize=12, pad=10)
    ax.grid(True, alpha=0.3)

    if len(present) == 2:
        from scipy import stats  # imported lazily: only needed when both groups exist
        _, p_value = stats.mannwhitneyu(present[0][2], present[1][2],
                                        alternative='two-sided')
        # `.3g` keeps very small p-values readable instead of printing 0.000.
        ax.text(0.02, 0.98, f'p-value: {p_value:.3g}',
                transform=ax.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))


def _print_disease_summary(df: pd.DataFrame, disease_threshold: int) -> None:
    """Print per-group gene counts and the mean/median of ``uniqueDiseases``."""
    total = len(df)
    n_low = int((df['uniqueDiseases'] < disease_threshold).sum())
    n_high = int((df['uniqueDiseases'] >= disease_threshold).sum())

    def pct(n: int) -> float:
        # Guard against an empty DataFrame (total == 0).
        return n / total * 100 if total else 0.0

    print("\n=== Summary Statistics ===")
    print(f"Total genes: {total}")
    print(f"Genes with <{disease_threshold} diseases: {n_low} ({pct(n_low):.1f}%)")
    print(f"Genes with >={disease_threshold} diseases: {n_high} ({pct(n_high):.1f}%)")
    print(f"Mean unique diseases: {df['uniqueDiseases'].mean():.2f}")
    print(f"Median unique diseases: {df['uniqueDiseases'].median():.2f}")

In [None]:
# Build the simulated dataset (1000 genes) used by the rest of the notebook.
print("Generating sample essential genes data...")
essential_genes_df = create_sample_essential_genes_data(1000)

# Preview: as the cell's last expression, head() is rendered by Jupyter.
print("\nSample data:")
essential_genes_df.head()

In [None]:
# Draw the disease-group comparison figure for the simulated data (not saved to disk).
print("Creating essential genes analysis plots...")
plot_essential_genes_by_disease_groups(df=essential_genes_df, save_path=None)

In [None]:
# Additional analysis: pairwise correlations among the essential-gene columns
# and with the per-gene disease count.
print("\n=== Correlation Analysis ===")

# Columns ending with '_score' or starting with 'essential_' are analysed.
essential_cols = [
    name
    for name in essential_genes_df.columns
    if name.endswith('_score') or name.startswith('essential_')
]
corr_data = essential_genes_df[[*essential_cols, 'uniqueDiseases']].corr()

# Annotated heatmap of the full correlation matrix, centred on zero.
plt.figure(figsize=(12, 10))
sns.heatmap(
    corr_data,
    cmap='coolwarm',
    center=0,
    annot=True,
    fmt='.3f',
    square=True,
)
plt.title('Correlation Matrix: Essential Gene Scores and Disease Associations')
plt.tight_layout()
plt.show()

# List each column's correlation with uniqueDiseases, strongest (by absolute value) first.
disease_corr = (
    corr_data['uniqueDiseases']
    .drop('uniqueDiseases')
    .sort_values(key=abs, ascending=False)
)
print("\nCorrelations with unique diseases count:")
for name, value in disease_corr.items():
    print(f"{name}: {value:.3f}")