In [6]:
import os
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import boto3
from dotenv import load_dotenv
import anndata as ad
import shutil
from datetime import datetime

# Configuration constants
PVAL_THRESHOLD = 0.01
LOGFC_THRESHOLD = 10

print(f"Configuration: p < {PVAL_THRESHOLD}, |logFC| > {LOGFC_THRESHOLD}, using median")


Configuration: p < 0.01, |logFC| > 10, using median


In [7]:
def parse_s3_uri(s3_uri):
    """Parse S3 URI to extract bucket and key"""
    if not s3_uri.startswith('s3://'):
        raise ValueError("Invalid S3 URI. Must start with 's3://'")
    
    # Remove 's3://' prefix
    path = s3_uri[5:]
    
    # Split into bucket and key
    parts = path.split('/', 1)
    bucket = parts[0]
    key = parts[1] if len(parts) > 1 else ''
    
    return bucket, key

def download_from_s3(bucket, key, local_path, aws_access_key_id, aws_secret_access_key):
    """Download file from S3 using provided credentials"""
    # Create a session using provided credentials
    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )

    s3 = session.client('s3')
    try:
        s3.download_file(bucket, key, local_path)
        return True
    except ClientError as e:
        print(f"Error downloading from S3: {e}")
        return False

def load_config():
    """Load configuration from .env file in parent directory"""
    load_dotenv('../.env')
    
    config = {
        'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'),
        'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'),
        'data_uri': os.getenv('DATA_URI')
    }
    
    # Validate that all required variables are present
    missing_vars = [key for key, value in config.items() if not value]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
    
    return config

def check_existing_data(data_dir):
    """Check what data files already exist in the data directory"""
    if not os.path.exists(data_dir):
        print(f"📁 Data directory does not exist: {data_dir}")
        return []
    
    files = [f for f in os.listdir(data_dir) if f.endswith('.h5ad')]
    
    if files:
        print(f"📁 Found {len(files)} existing data file(s) in {data_dir}:")
        for file in files:
            file_path = os.path.join(data_dir, file)
            file_size = os.path.getsize(file_path)
            print(f"   • {file} ({file_size:,} bytes, {file_size/1024/1024:.1f} MB)")
    else:
        print(f"📁 No .h5ad files found in {data_dir}")
    
    return files

def download_data(data_dir, force_download=False):
    """Main function to orchestrate the data download"""
    # Load configuration
    config = load_config()
    
    # Create data directory if it doesn't exist
    os.makedirs(data_dir, exist_ok=True)
    print(f"Data directory: {os.path.abspath(data_dir)}")
    
    # Parse the DATA_URI and download the file
    bucket, key = parse_s3_uri(config['data_uri'])
    
    # Extract filename from S3 key or use default
    filename = os.path.basename(key) if key else 'test_data.h5ad'
    if not filename.endswith('.h5ad'):
        filename = 'test_data.h5ad'
    
    # Full path to save the file
    local_path = os.path.join(data_dir, filename)
    
    # Check if file already exists
    if os.path.exists(local_path) and not force_download:
        file_size = os.path.getsize(local_path)
        print(f"✅ Data file already exists: {os.path.abspath(local_path)}")
        print(f"   File size: {file_size:,} bytes ({file_size/1024/1024:.1f} MB)")
        print("   Skipping download. Use force_download=True to re-download.")
        return True, local_path
    
    # Download the file
    print(f"Downloading data from S3...")
    success = download_from_s3(
        bucket=bucket,
        key=key,
        local_path=local_path,
        aws_access_key_id=config['aws_access_key_id'],
        aws_secret_access_key=config['aws_secret_access_key']
    )
    
    if success:
        file_size = os.path.getsize(local_path)
        print(f"✅ Data download completed successfully!")
        print(f"   File saved to: {os.path.abspath(local_path)}")
        print(f"   File size: {file_size:,} bytes ({file_size/1024/1024:.1f} MB)")
    else:
        print("Data download failed!")
    
    return success, local_path if success else None

# Execute the download
data_dir = '../data'
print("=== DATA MANAGEMENT ===")

# Check existing data files first
existing_files = check_existing_data(data_dir)

# Download data (will skip if already exists)
success, data_file = download_data(data_dir, force_download=False)

# Optional: Force re-download if needed
# success, data_file = download_data(data_dir, force_download=True)

=== DATA MANAGEMENT ===
📁 Found 1 existing data file(s) in ../data:
   • test_data.h5ad (124,918,273 bytes, 119.1 MB)
Data directory: /Users/jmsung/projects/bio-ml/diff_exp/data
✅ Data file already exists: /Users/jmsung/projects/bio-ml/diff_exp/data/test_data.h5ad
   File size: 124,918,273 bytes (119.1 MB)
   Skipping download. Use force_download=True to re-download.


In [8]:
# Load and explore the h5ad file
if success and data_file:
    print("Loading data...")
    adata = ad.read_h5ad(data_file)
    
    print(f"Loaded data with shape: {adata.shape}")
    print(f"Number of genes (variables): {adata.n_vars}")
    print(f"Number of cells (observations): {adata.n_obs}")

    print("\n" + "="*50)
    print("DATA EXPLORATION")
    print("="*50)

    # Basic data structure
    print("\n1. Data Matrix Information:")
    print(f"   - Data type: {type(adata.X)}")
    print(f"   - Matrix shape: {adata.X.shape}")
    print(f"   - Data dtype: {adata.X.dtype}")

    # Gene information
    print("\n2. Gene (Variable) Information:")
    print(f"   - First 10 gene names: {adata.var_names[:10].tolist()}")
    print(f"   - Gene annotation columns: {adata.var.columns.tolist()}")
    if not adata.var.empty:
        print("   - Gene annotations preview:")
        print(adata.var.head())

    # Cell information  
    print("\n3. Cell (Observation) Information:")
    print(f"   - Cell annotation columns: {adata.obs.columns.tolist()}")
    print("   - Cell annotations preview:")
    print(adata.obs.head())

    # Check for condition/treatment information
    print("\n4. Condition/Treatment Analysis:")
    if 'condition' in adata.obs.columns:
        print(f"   - Found 'condition' column!")
        print(f"   - Conditions: {adata.obs['condition'].value_counts().to_dict()}")
    elif 'treatment' in adata.obs.columns:
        print(f"   - Found 'treatment' column!")
        print(f"   - Treatments: {adata.obs['treatment'].value_counts().to_dict()}")
    else:
        print("   - No 'condition' or 'treatment' column found")
        print("   - Available columns for grouping:")
        for col in adata.obs.columns:
            if adata.obs[col].dtype == 'object' or adata.obs[col].dtype.name == 'category':
                print(f"     * {col}: {adata.obs[col].unique()[:5]}")

    # Expression data summary
    print("\n5. Expression Data Summary:")
    if hasattr(adata.X, 'toarray'):
        # Sparse matrix
        sample_data = adata.X[:100, :100].toarray()
        print(f"   - Matrix type: Sparse")
        print(f"   - Sample values (first 5x5):")
        print(sample_data[:5, :5])
    else:
        # Dense matrix
        print(f"   - Matrix type: Dense")
        print(f"   - Sample values (first 5x5):")
        print(adata.X[:5, :5])

    print(f"   - Expression range: {adata.X.min():.3f} to {adata.X.max():.3f}")
    print(f"   - Mean expression: {adata.X.mean():.3f}")

    # Check for other data layers
    print("\n6. Additional Data Layers:")
    if adata.layers:
        print(f"   - Available layers: {list(adata.layers.keys())}")
    else:
        print("   - No additional layers found")

    if adata.obsm:
        print(f"   - Observation matrices (obsm): {list(adata.obsm.keys())}")
    else:
        print("   - No observation matrices found")

    if adata.varm:
        print(f"   - Variable matrices (varm): {list(adata.varm.keys())}")
    else:
        print("   - No variable matrices found")
        
else:
    print("❌ Cannot load data: Download failed or file path not available")
    print("Please run the download cell first")


Loading data...
Loaded data with shape: (5025, 61198)
Number of genes (variables): 61198
Number of cells (observations): 5025

DATA EXPLORATION

1. Data Matrix Information:
   - Data type: <class 'scipy.sparse._csr.csr_matrix'>
   - Matrix shape: (5025, 61198)
   - Data dtype: float32

2. Gene (Variable) Information:
   - First 10 gene names: ['DDX11L1', 'WASH7P', 'MIR6859-1', 'MIR1302-2HG', 'MIR1302-2', 'FAM138A', 'OR4G4P', 'OR4G11P', 'OR4F5', 'LOC100996442']
   - Gene annotation columns: []

3. Cell (Observation) Information:
   - Cell annotation columns: ['cell_id', 'group']
   - Cell annotations preview:
                 cell_id      group
7212  AAACCCAAGCCATTCA-1  treatment
7213  AAACCCAAGTGACACG-1    control
7214  AAACCCACAGCTGTCG-1    control
7215  AAACCCAGTATCTTCT-1    control
7216  AAACCCAGTCGAATGG-1    control

4. Condition/Treatment Analysis:
   - No 'condition' or 'treatment' column found
   - Available columns for grouping:
     * cell_id: ['AAACCCAAGCCATTCA-1' 'AAACCCAAGT

In [9]:
adata.X[:100, :100].toarray()


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], shape=(100, 100), dtype=float32)

In [10]:
# Updated outlier detection function for symmetric percentile removal
def detect_outliers_percentile(expression_data, percentile_threshold=95):
    """
    Detect outliers using symmetric percentile-based method
    
    Parameters:
    expression_data: array of expression values
    percentile_threshold: central percentage to keep (e.g., 95 means keep middle 95%, 
                         remove bottom 2.5% and top 2.5%)
    
    Returns:
    Boolean mask where True indicates outlier
    """
    if len(expression_data) == 0:
        return np.array([], dtype=bool)
    
    # Calculate symmetric bounds
    # For 95%, remove bottom 2.5% and top 2.5%
    lower_percentile = (100 - percentile_threshold) / 2
    upper_percentile = 100 - lower_percentile
    
    lower_bound = np.percentile(expression_data, lower_percentile)
    upper_bound = np.percentile(expression_data, upper_percentile)
    
    # Identify outliers (values below lower bound OR above upper bound)
    outlier_mask = (expression_data < lower_bound) | (expression_data > upper_bound)
    
    return outlier_mask

print(f"Updated outlier detection function: symmetric percentile removal")
print(f"For 95% threshold: removes bottom 2.5% and top 2.5% of values")
print(f"For 99% threshold: removes bottom 0.5% and top 0.5% of values")


Updated outlier detection function: symmetric percentile removal
For 95% threshold: removes bottom 2.5% and top 2.5% of values
For 99% threshold: removes bottom 0.5% and top 0.5% of values


In [11]:
# Updated outlier detection function for symmetric percentile removal
def detect_outliers_percentile(expression_data, percentile_threshold=95):
    """
    Detect outliers using symmetric percentile-based method
    
    Parameters:
    expression_data: array of expression values
    percentile_threshold: central percentage to keep (e.g., 95 means keep middle 95%, 
                         remove bottom 2.5% and top 2.5%)
    
    Returns:
    Boolean mask where True indicates outlier
    """
    if len(expression_data) == 0:
        return np.array([], dtype=bool)
    
    # Calculate symmetric bounds
    # For 95%, remove bottom 2.5% and top 2.5%
    lower_percentile = (100 - percentile_threshold) / 2
    upper_percentile = 100 - lower_percentile
    
    lower_bound = np.percentile(expression_data, lower_percentile)
    upper_bound = np.percentile(expression_data, upper_percentile)
    
    # Identify outliers (values below lower bound OR above upper bound)
    outlier_mask = (expression_data < lower_bound) | (expression_data > upper_bound)
    
    return outlier_mask

print(f"Updated outlier detection function: symmetric percentile removal")
print(f"For 95% threshold: removes bottom 2.5% and top 2.5% of values")
print(f"For 99% threshold: removes bottom 0.5% and top 0.5% of values")


Updated outlier detection function: symmetric percentile removal
For 95% threshold: removes bottom 2.5% and top 2.5% of values
For 99% threshold: removes bottom 0.5% and top 0.5% of values


In [12]:
# Updated outlier detection function for symmetric percentile removal
def detect_outliers_percentile(expression_data, percentile_threshold=95):
    """
    Detect outliers using symmetric percentile-based method
    
    Parameters:
    expression_data: array of expression values
    percentile_threshold: central percentage to keep (e.g., 95 means keep middle 95%, 
                         remove bottom 2.5% and top 2.5%)
    
    Returns:
    Boolean mask where True indicates outlier
    """
    if len(expression_data) == 0:
        return np.array([], dtype=bool)
    
    # Calculate symmetric bounds
    # For 95%, remove bottom 2.5% and top 2.5%
    lower_percentile = (100 - percentile_threshold) / 2
    upper_percentile = 100 - lower_percentile
    
    lower_bound = np.percentile(expression_data, lower_percentile)
    upper_bound = np.percentile(expression_data, upper_percentile)
    
    # Identify outliers (values below lower bound OR above upper bound)
    outlier_mask = (expression_data < lower_bound) | (expression_data > upper_bound)
    
    return outlier_mask

print(f"Updated outlier detection function: symmetric percentile removal")
print(f"For 95% threshold: removes bottom 2.5% and top 2.5% of values")
print(f"For 99% threshold: removes bottom 0.5% and top 0.5% of values")


Updated outlier detection function: symmetric percentile removal
For 95% threshold: removes bottom 2.5% and top 2.5% of values
For 99% threshold: removes bottom 0.5% and top 0.5% of values


In [13]:
# Updated outlier detection function for symmetric percentile removal
def detect_outliers_percentile(expression_data, percentile_threshold=95):
    """
    Detect outliers using symmetric percentile-based method
    
    Parameters:
    expression_data: array of expression values
    percentile_threshold: central percentage to keep (e.g., 95 means keep middle 95%, 
                         remove bottom 2.5% and top 2.5%)
    
    Returns:
    Boolean mask where True indicates outlier
    """
    if len(expression_data) == 0:
        return np.array([], dtype=bool)
    
    # Calculate symmetric bounds
    # For 95%, remove bottom 2.5% and top 2.5%
    lower_percentile = (100 - percentile_threshold) / 2
    upper_percentile = 100 - lower_percentile
    
    lower_bound = np.percentile(expression_data, lower_percentile)
    upper_bound = np.percentile(expression_data, upper_percentile)
    
    # Identify outliers (values below lower bound OR above upper bound)
    outlier_mask = (expression_data < lower_bound) | (expression_data > upper_bound)
    
    return outlier_mask

print(f"Updated outlier detection function: symmetric percentile removal")
print(f"For 95% threshold: removes bottom 2.5% and top 2.5% of values")
print(f"For 99% threshold: removes bottom 0.5% and top 0.5% of values")


Updated outlier detection function: symmetric percentile removal
For 95% threshold: removes bottom 2.5% and top 2.5% of values
For 99% threshold: removes bottom 0.5% and top 0.5% of values


In [14]:
def create_volcano_plot(df, title="Volcano Plot: Treatment vs Control", pval_threshold=PVAL_THRESHOLD, logFC_threshold=LOGFC_THRESHOLD):
    """
    Create interactive volcano plot using plotly with configurable significance thresholds
    
    Parameters:
    df: DataFrame with gene, logFC, p_val, -log10_pval columns
    title: plot title
    pval_threshold: p-value significance threshold (default: 0.01)
    logFC_threshold: log2 fold change threshold (default: 10)
    
    Returns:
    plotly figure
    """
    
    # Create significance categories
    df = df.copy()
    df['significance'] = 'Not Significant'
    
    # Significant genes (p < threshold and |logFC| > threshold)
    significant_mask = (df['p_val'] < pval_threshold) & (df['abs_logFC'] > logFC_threshold)
    df.loc[significant_mask, 'significance'] = 'Significant'
    
    # Upregulated (logFC > threshold, p < threshold)
    upregulated_mask = (df['p_val'] < pval_threshold) & (df['logFC'] > logFC_threshold)
    df.loc[upregulated_mask, 'significance'] = 'Upregulated'
    
    # Downregulated (logFC < -threshold, p < threshold)
    downregulated_mask = (df['p_val'] < pval_threshold) & (df['logFC'] < -logFC_threshold)
    df.loc[downregulated_mask, 'significance'] = 'Downregulated'
    
    # Color mapping
    color_map = {
        'Not Significant': 'lightgray',
        'Significant': 'orange', 
        'Upregulated': 'red',
        'Downregulated': 'blue'
    }
    
    # Create the plot
    fig = px.scatter(
        df,
        x='logFC',
        y='-log10_pval',
        color='significance',
        color_discrete_map=color_map,
        hover_data=['gene', 'p_val'],
        title=f"{title}<br><sub>Thresholds: p < {pval_threshold}, |log2FC| > {logFC_threshold}</sub>",
        labels={
            'logFC': 'Log2 Fold Change (Treatment vs Control)',
            '-log10_pval': '-Log10(P-value)',
            'significance': 'Significance'
        },
        width=800,
        height=600
    )
    
    # Add significance threshold lines
    fig.add_hline(y=-np.log10(pval_threshold), line_dash="dash", line_color="black", 
                  annotation_text=f"p-value = {pval_threshold}")
    fig.add_vline(x=logFC_threshold, line_dash="dash", line_color="black", 
                  annotation_text=f"log2FC = {logFC_threshold}")
    fig.add_vline(x=-logFC_threshold, line_dash="dash", line_color="black", 
                  annotation_text=f"log2FC = -{logFC_threshold}")
    
    # Update layout
    fig.update_layout(
        template='plotly_white',
        showlegend=True,
        title_x=0.5,
        font=dict(size=12)
    )
    
    # Count genes in each category
    counts = df['significance'].value_counts()
    print(f"📊 Gene Categories (p < {pval_threshold}, |log2FC| > {logFC_threshold}):")
    for category, count in counts.items():
        print(f"  {category}: {count:,}")
    
    return fig

# Volcano plot function will be used after running the analysis
print("Volcano plot function ready - will be used after running differential expression analysis")

def create_result_dir(result_dir, backup_old=False):
    """
    Delete existing result directory and create a fresh one
    
    Parameters:
    result_dir: path to result directory
    backup_old: if True, rename old directory instead of deleting
    """
    
    # Handle existing result directory
    if os.path.exists(result_dir):
        if backup_old:
            # Create backup with timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_dir = f"{result_dir}_backup_{timestamp}"
            shutil.move(result_dir, backup_dir)
            print(f"📦 Backed up existing results to: {os.path.abspath(backup_dir)}")
        else:
            # Count existing files before deletion
            try:
                existing_files = len([f for f in os.listdir(result_dir) if os.path.isfile(os.path.join(result_dir, f))])
                shutil.rmtree(result_dir)
                print(f"🗑️  Deleted existing result directory ({existing_files} files): {os.path.abspath(result_dir)}")
            except Exception as e:
                print(f"⚠️  Warning: Could not delete result directory: {e}")
    else:
        print(f"📁 No existing result directory found")
    
    # Create fresh result directory
    os.makedirs(result_dir, exist_ok=True)
    print(f"✨ Created fresh result directory: {os.path.abspath(result_dir)}")
    
    return result_dir

# Create clean result directory  
result_dir = '../result'
print("=== RESULT DIRECTORY MANAGEMENT ===")

# Clean start: delete old results and create fresh directory
create_result_dir(result_dir, backup_old=False)

# Alternative: backup old results instead of deleting
# create_result_dir(result_dir, backup_old=True)

# Plot will be saved after analysis is run
print("Result directory ready for analysis outputs")


Volcano plot function ready - will be used after running differential expression analysis
=== RESULT DIRECTORY MANAGEMENT ===
🗑️  Deleted existing result directory (5 files): /Users/jmsung/projects/bio-ml/diff_exp/result
✨ Created fresh result directory: /Users/jmsung/projects/bio-ml/diff_exp/result
Result directory ready for analysis outputs


In [15]:
# Identify significantly differently expressed genes and create rankings

def identify_significant_genes(df, pval_threshold=PVAL_THRESHOLD, logFC_threshold=LOGFC_THRESHOLD):
    """
    Identify and rank significantly differently expressed genes
    
    Parameters:
    df: DataFrame with differential expression results
    pval_threshold: p-value cutoff for significance
    logFC_threshold: log fold change cutoff for significance
    
    Returns:
    Dictionary with upregulated, downregulated, and all significant genes
    """
    
    # Filter for significant genes
    significant_genes = df[
        (df['p_val'] < pval_threshold) & 
        (df['abs_logFC'] > logFC_threshold)
    ].copy()
    
    # Separate upregulated and downregulated
    upregulated = significant_genes[significant_genes['logFC'] > 0].copy()
    downregulated = significant_genes[significant_genes['logFC'] < 0].copy()
    
    # Sort upregulated by logFC (descending)
    upregulated = upregulated.sort_values('logFC', ascending=False)
    
    # Sort downregulated by logFC (ascending, most negative first)
    downregulated = downregulated.sort_values('logFC', ascending=True)
    
    # All significant genes sorted by significance (p-value then abs_logFC)
    all_significant = significant_genes.sort_values(['p_val', 'abs_logFC'], ascending=[True, False])
    
    return {
        'all_significant': all_significant,
        'upregulated': upregulated,
        'downregulated': downregulated
    }

# This function will be used after running the analysis
print("Significant genes identification function ready")


Significant genes identification function ready


In [16]:
# Simplified differential expression analysis using median (robust to outliers)
def calculate_differential_expression(adata, group_col='group', treatment_name='treatment', control_name='control'):
    """
    Calculate differential expression using median-based fold change (no outlier removal needed)
    
    Parameters:
    adata: AnnData object
    group_col: column name for grouping (default: 'group')
    treatment_name: name of treatment group (default: 'treatment')  
    control_name: name of control group (default: 'control')
    
    Returns:
    DataFrame with columns: gene, logFC, p_val, abs_logFC, -log10_pval
    """
    
    print("🔬 Starting median-based differential expression analysis...")
    print(f"📊 Dataset: {adata.n_obs} cells × {adata.n_vars} genes")
    print(f"🎯 Comparison: {treatment_name} vs {control_name}")
    print("📈 Using MEDIAN for fold change calculation (naturally robust to outliers)")
    print("🚫 No outlier detection needed - simplified and faster analysis")
    
    # Get group masks
    treatment_mask = adata.obs[group_col] == treatment_name
    control_mask = adata.obs[group_col] == control_name
    
    print(f"👥 Group sizes: {treatment_name}={treatment_mask.sum()}, {control_name}={control_mask.sum()}")
    
    # Initialize results storage
    results = []
    total_genes = adata.n_vars
    
    print(f"🧬 Processing {total_genes:,} genes...")
    
    # Process each gene
    for gene_idx in range(total_genes):
        gene_name = adata.var_names[gene_idx]
        
        # Extract expression data for this gene
        if hasattr(adata.X, 'toarray'):
            gene_expr = adata.X[:, gene_idx].toarray().flatten()
        else:
            gene_expr = adata.X[:, gene_idx].flatten()
        
        # Split by group
        treatment_gene_expr = gene_expr[treatment_mask]
        control_gene_expr = gene_expr[control_mask]
        
        # Skip genes with insufficient data
        if len(treatment_gene_expr) < 3 or len(control_gene_expr) < 3:
            continue
        
        # Calculate medians (robust to outliers)
        treatment_median = np.median(treatment_gene_expr)
        control_median = np.median(control_gene_expr)
        
        # Calculate log2 fold change using medians (add small epsilon to avoid log(0))
        epsilon = 1e-6
        log_fc = np.log2((treatment_median + epsilon) / (control_median + epsilon))
        
        # Perform Mann-Whitney U test (non-parametric, robust to outliers)
        try:
            # Use alternative='two-sided' for two-tailed test
            u_stat, p_val = mannwhitneyu(treatment_gene_expr, control_gene_expr, alternative='two-sided')
            
            # Handle potential NaN/infinite p-values
            if np.isnan(p_val) or np.isinf(p_val):
                p_val = 1.0
                
        except Exception as e:
            # If test fails, set p-value to 1 (not significant)
            p_val = 1.0
        
        # Store results
        results.append({
            'gene': gene_name,
            'logFC': log_fc,
            'p_val': p_val,
            'treatment_median': treatment_median,
            'control_median': control_median,
            'treatment_n': len(treatment_gene_expr),
            'control_n': len(control_gene_expr)
        })
        
        # Progress indicator
        if (gene_idx + 1) % 5000 == 0:
            print(f"   Processed {gene_idx + 1:,}/{total_genes:,} genes ({(gene_idx + 1)/total_genes*100:.1f}%)")
    
    # Convert to DataFrame
    df = pd.DataFrame(results)
    
    # Add derived columns
    df['abs_logFC'] = np.abs(df['logFC'])
    df['-log10_pval'] = -np.log10(df['p_val'].clip(lower=1e-300))  # Clip to avoid log(0)
    
    # Sort by p-value, then by absolute log fold change
    df = df.sort_values(['p_val', 'abs_logFC'], ascending=[True, False]).reset_index(drop=True)
    
    # Print summary statistics
    print(f"\n📈 Analysis Complete!")
    print(f"   • Genes analyzed: {len(df):,}")
    print(f"   • Statistical test: Mann-Whitney U (non-parametric)")
    print(f"   • Fold change: Log2(median_treatment / median_control)")
    print(f"   • No outlier removal needed - median is naturally robust")
    
    return df

print("✅ Median-based differential expression analysis function ready")


✅ Median-based differential expression analysis function ready


In [17]:
# Clean volcano plot function using configured constants
def create_volcano_plot(df, title="Volcano Plot: Treatment vs Control", 
                             pval_threshold=PVAL_THRESHOLD, logFC_threshold=LOGFC_THRESHOLD):
    """
    Create interactive volcano plot using plotly with configured default thresholds
    
    Parameters:
    df: DataFrame with gene, logFC, p_val, -log10_pval columns
    title: plot title
    pval_threshold: p-value significance threshold (default: PVAL_THRESHOLD)
    logFC_threshold: log2 fold change threshold (default: LOGFC_THRESHOLD)
    
    Returns:
    plotly figure
    """
    
    # Create significance categories
    df = df.copy()
    df['significance'] = 'Not Significant'
    
    # Significant genes (p < threshold and |logFC| > threshold)
    significant_mask = (df['p_val'] < pval_threshold) & (df['abs_logFC'] > logFC_threshold)
    df.loc[significant_mask, 'significance'] = 'Significant'
    
    # Upregulated (logFC > threshold, p < threshold)
    upregulated_mask = (df['p_val'] < pval_threshold) & (df['logFC'] > logFC_threshold)
    df.loc[upregulated_mask, 'significance'] = 'Upregulated'
    
    # Downregulated (logFC < -threshold, p < threshold)
    downregulated_mask = (df['p_val'] < pval_threshold) & (df['logFC'] < -logFC_threshold)
    df.loc[downregulated_mask, 'significance'] = 'Downregulated'
    
    # Color mapping
    color_map = {
        'Not Significant': 'lightgray',
        'Significant': 'orange', 
        'Upregulated': 'red',
        'Downregulated': 'blue'
    }
    
    # Create the plot
    fig = px.scatter(
        df,
        x='logFC',
        y='-log10_pval',
        color='significance',
        color_discrete_map=color_map,
        hover_data=['gene', 'p_val'],
        title=f"{title}<br><sub>Median-based: p < {pval_threshold}, |log2FC| > {logFC_threshold}</sub>",
        labels={
            'logFC': 'Log2 Fold Change (Median Treatment vs Control)',
            '-log10_pval': '-Log10(P-value)',
            'significance': 'Significance'
        },
        width=800,
        height=600
    )
    
    # Add significance threshold lines
    fig.add_hline(y=-np.log10(pval_threshold), line_dash="dash", line_color="black", 
                  annotation_text=f"p-value = {pval_threshold}")
    fig.add_vline(x=logFC_threshold, line_dash="dash", line_color="black", 
                  annotation_text=f"log2FC = {logFC_threshold}")
    fig.add_vline(x=-logFC_threshold, line_dash="dash", line_color="black", 
                  annotation_text=f"log2FC = -{logFC_threshold}")
    
    # Update layout
    fig.update_layout(
        template='plotly_white',
        showlegend=True,
        title_x=0.5,
        font=dict(size=12)
    )
    
    # Count genes in each category
    counts = df['significance'].value_counts()
    print(f"📊 Gene Categories (median-based, p < {pval_threshold}, |log2FC| > {logFC_threshold}):")
    for category, count in counts.items():
        print(f"  {category}: {count:,}")
    
    return fig

print("✅ Volcano plot function ready - uses configured constants")


✅ Volcano plot function ready - uses configured constants


In [18]:
print("🚀 Analysis Usage:")
print()
print("# Step 1: Run differential expression analysis")
print("df = calculate_differential_expression(adata)")
print()
print("# Step 2: Create volcano plot") 
print("fig = create_volcano_plot(df)")
print()
print("# Step 3: Identify significant genes")
print("sig_genes = identify_significant_genes(df)")
print()
print(f"📊 Current settings: p < {PVAL_THRESHOLD}, |log2FC| > {LOGFC_THRESHOLD}")
print("📈 Median-based fold change (no outlier detection needed)")
print("🧪 Mann-Whitney U test (non-parametric)")
print()
print("💡 To change thresholds: modify constants at top and re-run")


🚀 Analysis Usage:

# Step 1: Run differential expression analysis
df = calculate_differential_expression(adata)

# Step 2: Create volcano plot
fig = create_volcano_plot(df)

# Step 3: Identify significant genes
sig_genes = identify_significant_genes(df)

📊 Current settings: p < 0.01, |log2FC| > 10
📈 Median-based fold change (no outlier detection needed)
🧪 Mann-Whitney U test (non-parametric)

💡 To change thresholds: modify constants at top and re-run


In [19]:
# Export results as requested

def export_results(results_df, significant_results, result_dir):
    """
    Export differential expression results to CSV files
    
    Parameters:
    results_df: Complete differential expression results
    significant_results: Dictionary with significant gene results
    result_dir: Directory to save results
    """
    
    # 1. Save complete results DataFrame (gene, logFC, p_val)
    final_results = results_df[['gene', 'logFC', 'p_val']].copy()
    results_path = os.path.join(result_dir, 'differential_expression_results.csv')
    final_results.to_csv(results_path, index=False)
    print(f"✅ Complete results saved to: {results_path}")
    print(f"   Contains {len(final_results):,} genes with gene, logFC, p_val columns")
    
    # 2. Save top 20 significant genes
    if len(significant_results['all_significant']) > 0:
        top_20 = significant_results['all_significant'].head(20)[['gene', 'logFC', 'p_val']].copy()
        top_20_path = os.path.join(result_dir, 'top_20_significant_genes.csv')
        top_20.to_csv(top_20_path, index=False)
        print(f"✅ Top 20 significant genes saved to: {top_20_path}")
    else:
        print("⚠️  No significant genes found to export")
    
    # 3. Save upregulated genes
    if len(significant_results['upregulated']) > 0:
        upregulated_path = os.path.join(result_dir, 'upregulated_genes.csv')
        significant_results['upregulated'][['gene', 'logFC', 'p_val']].to_csv(upregulated_path, index=False)
        print(f"✅ Upregulated genes saved to: {upregulated_path}")
    
    # 4. Save downregulated genes  
    if len(significant_results['downregulated']) > 0:
        downregulated_path = os.path.join(result_dir, 'downregulated_genes.csv')
        significant_results['downregulated'][['gene', 'logFC', 'p_val']].to_csv(downregulated_path, index=False)
        print(f"✅ Downregulated genes saved to: {downregulated_path}")
    
    # 5. Save summary statistics
    summary_stats = {
        'total_genes_analyzed': len(results_df),
        'significant_genes': len(significant_results['all_significant']),
        'upregulated_genes': len(significant_results['upregulated']),
        'downregulated_genes': len(significant_results['downregulated']),
        'significance_threshold_pval': 0.05,
        'significance_threshold_logFC': 1.0
    }
    
    summary_df = pd.DataFrame([summary_stats])
    summary_path = os.path.join(result_dir, 'analysis_summary.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f"✅ Analysis summary saved to: {summary_path}")
    
    return {
        'results_path': results_path,
        'top_20_path': top_20_path if len(significant_results['all_significant']) > 0 else None,
        'summary_path': summary_path
    }

# Export results function - will be used after analysis is run
print("Export results function ready - will be used after analysis completes")


Export results function ready - will be used after analysis completes


In [20]:
# ===============================
# RUN COMPLETE ANALYSIS PIPELINE
# ===============================

if 'adata' in locals() and adata is not None:
    print("🚀 Running complete differential expression analysis...")
    
    # Step 1: Run differential expression analysis
    print("\n📊 Step 1: Calculating differential expression...")
    diff_expr_results = calculate_differential_expression(adata)
    
    # Step 2: Identify significant genes
    print("\n🎯 Step 2: Identifying significant genes...")
    significant_results = identify_significant_genes(diff_expr_results)
    
    print(f"\n=== RESULTS SUMMARY ===")
    print(f"Total genes analyzed: {len(diff_expr_results):,}")
    print(f"Significant genes: {len(significant_results['all_significant']):,}")
    print(f"  • Upregulated: {len(significant_results['upregulated']):,}")
    print(f"  • Downregulated: {len(significant_results['downregulated']):,}")
    
    # Step 3: Create volcano plot
    print("\n📈 Step 3: Creating volcano plot...")
    volcano_fig = create_volcano_plot(diff_expr_results, title="Differential Expression Analysis")
    
    # Step 4: Create results directory and save outputs
    print("\n💾 Step 4: Saving results...")
    result_dir = '../result'
    create_result_dir(result_dir, backup_old=False)
    
    # Save volcano plot
    plot_path = os.path.join(result_dir, 'volcano_plot.html')
    volcano_fig.write_html(plot_path)
    print(f"✅ Volcano plot saved: {plot_path}")
    
    # Export results
    export_paths = export_results(diff_expr_results, significant_results, result_dir)
    
    # Display top results
    if len(significant_results['all_significant']) > 0:
        print(f"\n🔝 TOP 10 SIGNIFICANT GENES:")
        top_10 = significant_results['all_significant'].head(10)
        print(top_10[['gene', 'logFC', 'p_val']].to_string(index=False))
    else:
        print("\n⚠️  No significant genes found with current thresholds")
        print(f"   Consider lowering thresholds: p < {PVAL_THRESHOLD}, |logFC| > {LOGFC_THRESHOLD}")
    
    print(f"\n✅ ANALYSIS COMPLETE!")
    print(f"📁 Results saved to: {os.path.abspath(result_dir)}")
    
else:
    print("⚠️  Data not loaded. Please run the data loading cells first.")
    print("   The analysis will run automatically once 'adata' is available.")


🚀 Running complete differential expression analysis...

📊 Step 1: Calculating differential expression...
🔬 Starting median-based differential expression analysis...
📊 Dataset: 5025 cells × 61198 genes
🎯 Comparison: treatment vs control
📈 Using MEDIAN for fold change calculation (naturally robust to outliers)
🚫 No outlier detection needed - simplified and faster analysis
👥 Group sizes: treatment=2543, control=2482
🧬 Processing 61,198 genes...
   Processed 5,000/61,198 genes (8.2%)
   Processed 10,000/61,198 genes (16.3%)
   Processed 15,000/61,198 genes (24.5%)
   Processed 20,000/61,198 genes (32.7%)
   Processed 25,000/61,198 genes (40.9%)
   Processed 30,000/61,198 genes (49.0%)
   Processed 35,000/61,198 genes (57.2%)
   Processed 40,000/61,198 genes (65.4%)
   Processed 45,000/61,198 genes (73.5%)
   Processed 50,000/61,198 genes (81.7%)
   Processed 55,000/61,198 genes (89.9%)
   Processed 60,000/61,198 genes (98.0%)

📈 Analysis Complete!
   • Genes analyzed: 61,198
   • Statistic