# MAGeCK GWAS/KO Screen Analysis

Analysis of pooled CRISPR KO screens across multiple donors using MAGeCK. Inputs include `sample_info.csv` for donor/bin mapping, `CSPA.csv` for surface protein annotations, and `brunello_library.txt` as the sgRNA library reference.

**Reproducibility & usage**
- Run cells top-to-bottom.
- Python ≥3.10. Install dependencies from `environment.yml` or `requirements.txt`.
- Random seeds are fixed for NumPy and Python's `random` module; figure-saving paths are centralized.

In [None]:
# ---- Standard imports & config ----
import os, sys
from pathlib import Path
import json, math, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Reproducibility: seed both Python's `random` module and NumPy's global RNG
random.seed(1337)
np.random.seed(1337)

# Project paths (edit as needed)
# If a `.here` marker file exists in the current working directory, PROJ_ROOT
# becomes the parent of the resolved ".." (i.e. the grandparent of the cwd);
# otherwise it is the cwd itself.
# NOTE(review): a `.here` marker usually flags the project root itself, so the
# grandparent looks surprising — confirm this is the intended layout.
PROJ_ROOT = Path("..").resolve().parent if (Path.cwd()/".here").exists() else Path(".").resolve()
DATA_DIR = PROJ_ROOT / "data"
RESULTS_DIR = PROJ_ROOT / "results"
FIG_DIR = PROJ_ROOT / "figures"
# Create the three output/input directories up front (no error if they already exist)
for d in [DATA_DIR, RESULTS_DIR, FIG_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Matplotlib defaults for publication
plt.rcParams.update({
    "figure.dpi": 150,
    "savefig.bbox": "tight",       # crop saved figures to their content
    "axes.spines.top": False,      # hide the top and right axes borders
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.alpha": 0.2,             # faint grid lines
})

def savefig(name, ext="png"):
    """Save the current matplotlib figure to ``FIG_DIR/<name>.<ext>``.

    Args:
        name: File name without extension.
        ext: File extension (without the dot); defaults to "png".

    The destination path is printed after saving.
    """
    destination = FIG_DIR.joinpath(f"{name}.{ext}")
    plt.savefig(destination)
    print(f"[saved] {destination}")

In [None]:
## this notebook is for analyzing genomic KO screen
## using MAGeCK
## plotting top hits
from scipy.stats import spearmanr
import matplotlib.pyplot as plt

In [None]:
#pip install gseapy

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Directory that holds the screen data (absolute path on the analysis machine)
base_dir = '/data/JE_misc/GW_screen/'
# Full path of sample_info.csv inside base_dir; the file is not read in this cell
sample_info_file = os.path.join(base_dir, 'sample_info.csv')

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Function to process a single count file
def process_count_file(file_path):
    """Return per-gene summed 'control' and 'treatment' counts from one count table.

    The tab-separated file is read, rows whose 'Gene' equals
    'Non-Targeting Control' are discarded, and the remaining rows are summed
    per gene.

    Args:
        file_path: Path to a tab-separated table with at least the columns
            'Gene', 'control' and 'treatment'.

    Returns:
        A DataFrame indexed by 'Gene' with columns 'control' and 'treatment',
        or None if anything goes wrong (the failure is printed, not raised).
    """
    try:
        table = pd.read_csv(file_path, sep='\t')
        targeting = table.loc[table['Gene'] != 'Non-Targeting Control']
        return targeting.groupby('Gene')[['control', 'treatment']].sum()
    except Exception as e:
        # Best-effort: report and let the caller skip this file
        print(f"Failed to process {file_path}: {e}")
        return None

# Root of the directory tree that is searched for count tables
base_dir = '/data/JE_misc/GW_screen'

# (file name, per-gene counts) pairs for every *.count.txt found below base_dir
all_gene_counts = []

for root, dirs, files in os.walk(base_dir):
    for file in files:
        if not file.endswith(".count.txt"):
            continue
        full_path = os.path.join(root, file)
        print(f"Processing file: {full_path}")
        gene_counts = process_count_file(full_path)
        if gene_counts is None:
            # process_count_file has already printed the failure
            continue
        all_gene_counts.append((file, gene_counts))

# One histogram panel per count file on a fixed 4 x 4 grid.
# zip() stops at the 16 available axes, so any further files are not drawn.
fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(16, 16))
axs = axs.flatten()

hist_options = dict(bins=50, alpha=0.7, log=True)
for ax, (filename, gene_counts) in zip(axs, all_gene_counts):
    for column in ('control', 'treatment'):
        ax.hist(gene_counts[column], label=column, **hist_options)
    ax.set_title(filename, fontsize=8)
    ax.set_xscale('log')
    ax.set_xlabel('Total Gene Count')
    ax.set_ylabel('Frequency')
    ax.legend(fontsize=6)

plt.tight_layout()
plt.show()


In [None]:
import glob

# Location of the screen results and the experimental factors used below
base_path = "/data/JE_misc/GW_screen"
conditions = ["Ark312", "Ark315", "AAV6"]
replicates = ["rep1", "rep2"]
donors=['Donor20','Donor21']

# Load gene_summary.txt
def load_gene_summary(donor, condition, rep):
    """Load the gene-level log fold changes for one donor/condition/replicate.

    Looks for files matching
    ``<base_path>/<donor>/<condition>/<rep>/*gene_summary.txt`` and reads the
    first one in sorted order as a tab-separated table.

    Args:
        donor: Donor directory name, e.g. "Donor20".
        condition: Condition directory name, e.g. "Ark312".
        rep: Replicate directory name, e.g. "rep1".

    Returns:
        DataFrame indexed by 'gene' (renamed from 'id') with a single column
        'neg_lfc' (renamed from 'neg|lfc').

    Raises:
        FileNotFoundError: if no file matches the pattern.
    """
    path_pattern = os.path.join(base_path, donor, condition, rep, "*gene_summary.txt")
    # glob.glob returns matches in arbitrary order; sorting makes the choice
    # of files[0] reproducible when more than one summary is present.
    files = sorted(glob.glob(path_pattern))
    if not files:
        raise FileNotFoundError(f"No file found for {donor}, {condition}, {rep}")
    df = pd.read_csv(files[0], sep="\t")
    df = df.rename(columns={"id": "gene", "neg|lfc": "neg_lfc"})
    return df[["gene", "neg_lfc"]].set_index("gene")

# Set up panel plot: 2 rows (replicates) × 3 columns (conditions)
# Each panel compares Donor20 vs Donor21 gene-level LFC for one replicate/condition.
fig, axes = plt.subplots(2, 3, figsize=(15, 8), sharex=True, sharey=True)

for i, rep in enumerate(replicates):
    for j, condition in enumerate(conditions):
        ax = axes[i, j]
        # The two donors are hard-coded here rather than taken from `donors`
        df_d20 = load_gene_summary("Donor20", condition, rep)
        df_d21 = load_gene_summary("Donor21", condition, rep)
        # Join on the gene index; dropna() removes genes missing from Donor21
        merged = df_d20.join(df_d21, lsuffix="_Donor20", rsuffix="_Donor21").dropna()
        # Off-diagonal entry of the 2x2 correlation matrix (pandas default method)
        corr = merged.corr().iloc[0, 1]

        sns.scatterplot(
            x="neg_lfc_Donor20",
            y="neg_lfc_Donor21",
            data=merged,
            s=15,
            ax=ax
        )
        # Dashed reference lines through the origin
        ax.axhline(0, color='gray', linestyle='--', linewidth=1)
        ax.axvline(0, color='gray', linestyle='--', linewidth=1)
        ax.set_title(f"{condition} | {rep}\nr = {corr:.2f}")
        # Only label the outer axes: x on the bottom row, y on the first column
        if i == 1:
            ax.set_xlabel("Donor20 LFC")
        else:
            ax.set_xlabel("")
        if j == 0:
            ax.set_ylabel("Donor21 LFC")
        else:
            ax.set_ylabel("")

plt.tight_layout()
plt.show()


In [None]:
# Function to load gene_summary.txt
def load_gene_summary(donor, condition, rep):
    """Load the gene-level log fold changes for one donor/condition/replicate.

    Looks for files matching
    ``<base_path>/<donor>/<condition>/<rep>/*gene_summary.txt`` and reads the
    first one in sorted order as a tab-separated table.

    Args:
        donor: Donor directory name, e.g. "Donor20".
        condition: Condition directory name, e.g. "Ark312".
        rep: Replicate directory name, e.g. "rep1".

    Returns:
        DataFrame indexed by 'gene' (renamed from 'id') with a single column
        'neg_lfc' (renamed from 'neg|lfc').

    Raises:
        FileNotFoundError: if no file matches the pattern.
    """
    path_pattern = os.path.join(base_path, donor, condition, rep, "*gene_summary.txt")
    # glob.glob returns matches in arbitrary order; sorting makes the choice
    # of files[0] reproducible when more than one summary is present.
    files = sorted(glob.glob(path_pattern))
    if not files:
        raise FileNotFoundError(f"No file found for {donor}, {condition}, {rep}")
    df = pd.read_csv(files[0], sep="\t")
    df = df.rename(columns={"id": "gene", "neg|lfc": "neg_lfc"})
    return df[["gene", "neg_lfc"]].set_index("gene")

# Set up 2 (donors) × 3 (conditions) panel
# Each panel compares rep1 vs rep2 gene-level LFC for one donor/condition.
fig, axes = plt.subplots(2, 3, figsize=(15, 8), sharex=True, sharey=True)

for i, donor in enumerate(donors):
    for j, condition in enumerate(conditions):
        ax = axes[i, j]
        df_rep1 = load_gene_summary(donor, condition, "rep1")
        df_rep2 = load_gene_summary(donor, condition, "rep2")
        # Join on the gene index; dropna() removes genes missing from rep2
        merged = df_rep1.join(df_rep2, lsuffix="_rep1", rsuffix="_rep2").dropna()
        # Off-diagonal entry of the 2x2 correlation matrix (pandas default method)
        corr = merged.corr().iloc[0, 1]

        sns.scatterplot(
            x="neg_lfc_rep1",
            y="neg_lfc_rep2",
            data=merged,
            s=15,
            ax=ax
        )
        # Dashed reference lines through the origin
        ax.axhline(0, color='gray', linestyle='--', linewidth=1)
        ax.axvline(0, color='gray', linestyle='--', linewidth=1)
        ax.set_title(f"{condition} | {donor}\nr = {corr:.2f}")
        # Only label the outer axes: x on the bottom row, y on the first column
        if i == 1:
            ax.set_xlabel("Rep1 LFC")
        else:
            ax.set_xlabel("")
        if j == 0:
            ax.set_ylabel("Rep2 LFC")
        else:
            ax.set_ylabel("")

plt.tight_layout()
plt.show()

In [None]:
import pandas as pd
import numpy as np

# Load the gene summary data
# NOTE(review): this path is relative to the current working directory, unlike
# the absolute base paths used in other cells — confirm where the notebook runs.
df = pd.read_csv('Donor20/Ark312/rep2/mageck_Donor20_Ark312_rep2_test.gene_summary.txt', sep='\t')

# For every gene, take the statistics of one direction: the 'neg|*' columns
# when pos|lfc is negative, the 'pos|*' columns otherwise (NaN counts as
# "otherwise"). Done column-wise instead of with a Python-level row loop.
is_negative = df['pos|lfc'] < 0

# Add the selected values to the DataFrame
df['Selected_pval'] = np.where(is_negative, df['neg|p-value'], df['pos|p-value'])
df['Selected_lfc'] = np.where(is_negative, df['neg|lfc'], df['pos|lfc'])
df['Direction'] = np.where(is_negative, 'Negative', 'Positive')

# The plain-list variables are kept so that anything still referring to them
# keeps working. The selected FDR is only available as a list; it is not
# stored in the DataFrame and the plots below are based on p-values.
selected_pvals = df['Selected_pval'].tolist()
selected_fdrs = np.where(is_negative, df['neg|fdr'], df['pos|fdr']).tolist()
selected_lfcs = df['Selected_lfc'].tolist()
selected_directions = df['Direction'].tolist()

# Calculate -log10(p-value) for plotting (this is the p-value, not the FDR)
df['minus_log10_pval'] = -np.log10(df['Selected_pval'])


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("white")

# Define significance thresholds
pval_threshold = 0.05   # compared against the selected (unadjusted) p-value
lfc_threshold = 0.5     # minimum absolute log fold change

# Categorize genes based on thresholds
def categorize_gene(row):
    """Label one gene row as 'Enriched', 'Depleted' or 'Not Significant'.

    A gene is 'Enriched'/'Depleted' only if its Selected_pval is below
    pval_threshold AND its Selected_lfc is above lfc_threshold / below
    -lfc_threshold; everything else is 'Not Significant'. The thresholds are
    read from the module-level variables at call time.
    """
    if row['Selected_pval'] < pval_threshold:
        if row['Selected_lfc'] > lfc_threshold:
            return 'Enriched'
        elif row['Selected_lfc'] < -lfc_threshold:
            return 'Depleted'
    return 'Not Significant'

df['Category'] = df.apply(categorize_gene, axis=1)

# Create color palette
palette = {'Enriched': 'red', 'Depleted': 'blue', 'Not Significant': 'grey'}

# Create the plot
plt.figure(figsize=(10, 8))
sns.scatterplot(data=df, x='Selected_lfc', y='minus_log10_pval', hue='Category', palette=palette, edgecolor='none', alpha=0.7)

# Add threshold lines
plt.axhline(-np.log10(pval_threshold), linestyle='--', color='black', linewidth=1)
plt.axvline(lfc_threshold, linestyle='--', color='black', linewidth=1)
plt.axvline(-lfc_threshold, linestyle='--', color='black', linewidth=1)

# Annotate top genes: the 50 categorized genes with the smallest selected p-value
top_genes = df[df['Category'] != 'Not Significant'].sort_values('Selected_pval').head(50)
for _, row in top_genes.iterrows():
    plt.text(row['Selected_lfc'], row['minus_log10_pval'], row['id'], fontsize=8)

# Customize plot
plt.title('Volcano Plot of MAGeCK Results')
plt.xlabel('Log2 Fold Change')
plt.ylabel('-Log10(pval)')
plt.legend(title='Category')
plt.tight_layout()
plt.show()


In [None]:
#trying to rerun with replicate count files and just adding the two replicates together and redoing mageck test
import subprocess
# Define all donors and conditions
# These re-bind the module-level names used by earlier cells with the same values.
donors = ["Donor20", "Donor21"]
conditions = ["Ark312", "Ark315", "AAV6"]
# Root directory containing <donor>/<condition>/<rep>/ sub-directories
base_dir = "/data/JE_misc/GW_screen"

# Helper function to load and index a count file
def load_count_table(path):
    """Read a tab-separated count table and return it indexed by its 'sgRNA' column.

    Args:
        path: Path to the count file.

    Returns:
        DataFrame whose index is the 'sgRNA' column; all other columns are kept.
    """
    return pd.read_csv(path, sep='\t').set_index('sgRNA')

# Store commands for each pooled test
# NOTE(review): this list is never filled below; it is kept only so the name stays defined.
mageck_commands = []

# Loop through each donor and condition: sum the two replicate count tables
# per sgRNA, write the pooled table, and run `mageck test` on it.
for donor in donors:
    for condition in conditions:
        rep1_path = os.path.join(base_dir, donor, condition, "rep1", f"mageck_{donor}_{condition}_rep1.count.txt")
        rep2_path = os.path.join(base_dir, donor, condition, "rep2", f"mageck_{donor}_{condition}_rep2.count.txt")

        # Skip if files don't exist
        if not os.path.exists(rep1_path) or not os.path.exists(rep2_path):
            print(f"Missing files for {donor} {condition}")
            continue

        # Load replicates (indexed by sgRNA)
        rep1 = load_count_table(rep1_path)
        rep2 = load_count_table(rep2_path)

        # The in-place additions below align on the sgRNA index. An sgRNA that
        # is present in only one replicate would silently become NaN (or be
        # dropped), so refuse to pool such a pair instead of writing a
        # corrupted count table.
        mismatched = rep1.index.symmetric_difference(rep2.index)
        if len(mismatched) > 0:
            print(f"sgRNA sets differ between replicates for {donor} {condition} "
                  f"({len(mismatched)} sgRNAs); skipping")
            continue

        # Sum counts
        pooled = rep1[['Gene', 'control', 'treatment']].copy()
        pooled['control'] += rep2['control']
        pooled['treatment'] += rep2['treatment']

        # Save pooled counts
        pooled_dir = os.path.join(base_dir, donor, condition, "pooled")
        os.makedirs(pooled_dir, exist_ok=True)
        pooled_count_path = os.path.join(pooled_dir, f"mageck_{donor}_{condition}_pooled.count.txt")
        pooled.reset_index().to_csv(pooled_count_path, sep='\t', index=False)

        output_prefix = f"mageck_{donor}_{condition}_pooled_test"
        output_prefix_full = os.path.join(pooled_dir, output_prefix)

        print(f"Running MAGeCK test for {donor} {condition}...")

        # NOTE(review): `-n` looks like the short form of `--output-prefix`, so
        # the prefix is passed twice and the later full path presumably wins —
        # confirm against `mageck test --help`.
        try:
            subprocess.run([
                "mageck", "test",
                "-k", pooled_count_path,
                "-t", "treatment",
                "-c", "control",
                "-n", output_prefix,
                "--norm-method", "median",
                "--output-prefix", output_prefix_full
            ], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True, check=True)
        except subprocess.CalledProcessError as err:
            # Show MAGeCK's own error output before propagating the failure;
            # previously it was discarded, leaving only the exit code.
            print(f"MAGeCK test failed for {donor} {condition} "
                  f"(exit code {err.returncode}):\n{err.stderr}")
            raise

        print(f"Finished MAGeCK test: {output_prefix_full}")


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Root of the screen results and the conditions plotted in this cell
base_path = "/data/JE_misc/GW_screen"
conditions = ["Ark312", "Ark315", "AAV6"]

def load_pooled_gene_summary(donor, condition):
    """Return the pooled gene-level LFC table for one donor and condition.

    Reads the tab-separated file
    ``<base_path>/<donor>/<condition>/pooled/mageck_<donor>_<condition>_pooled_test.gene_summary.txt``.

    Returns:
        DataFrame indexed by 'gene' (renamed from 'id') with the single
        column 'neg_lfc' (renamed from 'neg|lfc').

    Raises:
        FileNotFoundError: if the expected file does not exist.
    """
    summary_name = f"mageck_{donor}_{condition}_pooled_test.gene_summary.txt"
    path = os.path.join(base_path, donor, condition, "pooled", summary_name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"No pooled gene_summary.txt for {donor}, {condition}")
    table = pd.read_csv(path, sep="\t").rename(columns={"id": "gene", "neg|lfc": "neg_lfc"})
    return table.set_index("gene")[["neg_lfc"]]

# Panel plot: 1 row × 3 columns
# One Donor20-vs-Donor21 scatter of pooled gene-level LFC per condition.
fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True, sharey=True)

for j, condition in enumerate(conditions):
    ax = axes[j]
    df_d20 = load_pooled_gene_summary("Donor20", condition)
    df_d21 = load_pooled_gene_summary("Donor21", condition)
    # Join on the gene index; dropna() removes genes missing from Donor21
    merged = df_d20.join(df_d21, lsuffix="_Donor20", rsuffix="_Donor21").dropna()

    # Compute std deviations
    # NOTE(review): sd_d20 / sd_d21 are not used anywhere in this cell.
    sd_d20 = merged["neg_lfc_Donor20"].std()
    sd_d21 = merged["neg_lfc_Donor21"].std()

    # Compute the correlation on all merged genes (no filtering is applied here)
    corr = merged[["neg_lfc_Donor20", "neg_lfc_Donor21"]].corr().iloc[0, 1]

    # Plot all merged genes in black; the SD-based highlight below is disabled
    sns.scatterplot(
        data=merged,
        x="neg_lfc_Donor20", y="neg_lfc_Donor21",
        s=15, ax=ax, color="black"
    )
    #sns.scatterplot(
    #    data=filtered,
    #    x="neg_lfc_Donor20", y="neg_lfc_Donor21",
    #    s=15, ax=ax, color="blue", label="≤4 SD"
    #)

    # Dashed reference lines through the origin
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    ax.axvline(0, color='gray', linestyle='--', linewidth=1)
    ax.set_title(f"{condition}\nPearson r ={corr:.2f}")
    ax.set_xlabel("Donor20 LFC")
    # Shared y axis: label only the first panel
    if j == 0:
        ax.set_ylabel("Donor21 LFC")
    else:
        ax.set_ylabel("")

plt.tight_layout()
# Written relative to the current working directory
plt.savefig("donor_donor_correlation.svg", format="svg")
plt.show()


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Root of the screen results and the conditions plotted in this cell
base_path = "/data/JE_misc/GW_screen"
conditions = ["Ark312", "Ark315", "AAV6"]

def load_pooled_gene_summary(donor, condition):
    """Return the pooled gene-level LFC table for one donor and condition.

    Reads the tab-separated file
    ``<base_path>/<donor>/<condition>/pooled/mageck_<donor>_<condition>_pooled_test.gene_summary.txt``.

    Returns:
        DataFrame indexed by 'gene' (renamed from 'id') with the single
        column 'neg_lfc' (renamed from 'neg|lfc').

    Raises:
        FileNotFoundError: if the expected file does not exist.
    """
    summary_name = f"mageck_{donor}_{condition}_pooled_test.gene_summary.txt"
    path = os.path.join(base_path, donor, condition, "pooled", summary_name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"No pooled gene_summary.txt for {donor}, {condition}")
    table = pd.read_csv(path, sep="\t").rename(columns={"id": "gene", "neg|lfc": "neg_lfc"})
    return table.set_index("gene")[["neg_lfc"]]

# Plot setup
# One panel per condition: all genes in light gray, >4 SD genes highlighted and labelled.
fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True, sharey=True)

for j, condition in enumerate(conditions):
    ax = axes[j]
    df_d20 = load_pooled_gene_summary("Donor20", condition)
    df_d21 = load_pooled_gene_summary("Donor21", condition)
    # Join on the gene index; dropna() removes genes missing from Donor21
    merged = df_d20.join(df_d21, lsuffix="_Donor20", rsuffix="_Donor21").dropna()

    # Compute std deviations
    sd_d20 = merged["neg_lfc_Donor20"].std()
    sd_d21 = merged["neg_lfc_Donor21"].std()

    # Keep only outliers > 4 SD in either donor
    # The cut-off is applied to |LFC| itself, i.e. distance from zero, not from the mean.
    outliers = merged[
        (merged["neg_lfc_Donor20"].abs() > 4 * sd_d20) |
        (merged["neg_lfc_Donor21"].abs() > 4 * sd_d21)
    ]

    # Compute correlation only on outliers
    corr = outliers[["neg_lfc_Donor20", "neg_lfc_Donor21"]].corr().iloc[0, 1]

    # Plot background all data
    sns.scatterplot(
        data=merged,
        x="neg_lfc_Donor20", y="neg_lfc_Donor21",
        s=15, ax=ax, color="lightgray"
    )
    # Plot outliers
    sns.scatterplot(
        data=outliers,
        x="neg_lfc_Donor20", y="neg_lfc_Donor21",
        s=20, ax=ax, color="red", label=">4 SD"
    )

    # Annotate outlier gene names (the index of `outliers` is the gene name)
    for gene, row in outliers.iterrows():
        ax.text(
            row["neg_lfc_Donor20"], 
            row["neg_lfc_Donor21"], 
            gene, 
            fontsize=7, 
            color="black", 
            alpha=0.7
        )

    # Dashed reference lines through the origin
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    ax.axvline(0, color='gray', linestyle='--', linewidth=1)
    ax.set_title(f"{condition}\nPearson r (>4 SD) = {corr:.2f}")
    ax.set_xlabel("Donor20 LFC")
    # Shared y axis: label only the first panel
    if j == 0:
        ax.set_ylabel("Donor21 LFC")
    else:
        ax.set_ylabel("")

plt.tight_layout()
# Written relative to the current working directory
plt.savefig("donor_donor_outliers.svg", format="svg")
plt.show()

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Root of the screen results and the conditions plotted in this cell
base_path = "/data/JE_misc/GW_screen"
conditions = ["Ark312", "Ark315", "AAV6"]

def load_pooled_gene_summary(donor, condition):
    """Return the pooled gene-level LFC table for one donor and condition.

    Reads the tab-separated file
    ``<base_path>/<donor>/<condition>/pooled/mageck_<donor>_<condition>_pooled_test.gene_summary.txt``.

    Returns:
        DataFrame indexed by 'gene' (renamed from 'id') with the single
        column 'neg_lfc' (renamed from 'neg|lfc').

    Raises:
        FileNotFoundError: if the expected file does not exist.
    """
    summary_name = f"mageck_{donor}_{condition}_pooled_test.gene_summary.txt"
    path = os.path.join(base_path, donor, condition, "pooled", summary_name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"No pooled gene_summary.txt for {donor}, {condition}")
    table = pd.read_csv(path, sep="\t").rename(columns={"id": "gene", "neg|lfc": "neg_lfc"})
    return table.set_index("gene")[["neg_lfc"]]

# Plot setup
# One panel per condition: all genes in light gray, >6 SD genes highlighted and labelled.
fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True, sharey=True)

for j, condition in enumerate(conditions):
    ax = axes[j]
    df_d20 = load_pooled_gene_summary("Donor20", condition)
    df_d21 = load_pooled_gene_summary("Donor21", condition)
    # Join on the gene index; dropna() removes genes missing from Donor21
    merged = df_d20.join(df_d21, lsuffix="_Donor20", rsuffix="_Donor21").dropna()

    # Compute std deviations
    sd_d20 = merged["neg_lfc_Donor20"].std()
    sd_d21 = merged["neg_lfc_Donor21"].std()

    # Keep only outliers > 6 SD in either donor
    # The cut-off is applied to |LFC| itself, i.e. distance from zero, not from the mean.
    outliers = merged[
        (merged["neg_lfc_Donor20"].abs() > 6 * sd_d20) |
        (merged["neg_lfc_Donor21"].abs() > 6 * sd_d21)
    ]

    # Compute correlation only on outliers
    corr = outliers[["neg_lfc_Donor20", "neg_lfc_Donor21"]].corr().iloc[0, 1]

    # Plot background all data
    sns.scatterplot(
        data=merged,
        x="neg_lfc_Donor20", y="neg_lfc_Donor21",
        s=15, ax=ax, color="lightgray"
    )
    # Plot outliers
    sns.scatterplot(
        data=outliers,
        x="neg_lfc_Donor20", y="neg_lfc_Donor21",
        s=20, ax=ax, color="red", label=">6 SD"
    )

    # Annotate outlier gene names (the index of `outliers` is the gene name)
    for gene, row in outliers.iterrows():
        ax.text(
            row["neg_lfc_Donor20"], 
            row["neg_lfc_Donor21"], 
            gene, 
            fontsize=7, 
            color="black", 
            alpha=0.7
        )

    # Dashed reference lines through the origin
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    ax.axvline(0, color='gray', linestyle='--', linewidth=1)
    ax.set_title(f"{condition}\nPearson r (>6 SD) = {corr:.2f}")
    ax.set_xlabel("Donor20 LFC")
    # Shared y axis: label only the first panel
    if j == 0:
        ax.set_ylabel("Donor21 LFC")
    else:
        ax.set_ylabel("")

plt.tight_layout()
# Saving of this figure is currently disabled
#plt.savefig("donor_donor_6SD_outliers.svg", format="svg")
plt.show()

In [None]:
import pandas as pd
import numpy as np

# Load the gene summary data
# NOTE(review): this path is relative to the current working directory, unlike
# the absolute base paths used in other cells — confirm where the notebook runs.
df = pd.read_csv('Donor20/Ark312/rep2/mageck_Donor20_Ark312_rep2_test.gene_summary.txt', sep='\t')

# For every gene, take the statistics of one direction: the 'neg|*' columns
# when pos|lfc is negative, the 'pos|*' columns otherwise (NaN counts as
# "otherwise"). Done column-wise instead of with a Python-level row loop.
is_negative = df['pos|lfc'] < 0

# Add the selected values to the DataFrame
df['Selected_pval'] = np.where(is_negative, df['neg|p-value'], df['pos|p-value'])
df['Selected_lfc'] = np.where(is_negative, df['neg|lfc'], df['pos|lfc'])
df['Direction'] = np.where(is_negative, 'Negative', 'Positive')

# The plain-list variables are kept so that anything still referring to them
# keeps working. The selected FDR is only available as a list; it is not
# stored in the DataFrame and the plots below are based on p-values.
selected_pvals = df['Selected_pval'].tolist()
selected_fdrs = np.where(is_negative, df['neg|fdr'], df['pos|fdr']).tolist()
selected_lfcs = df['Selected_lfc'].tolist()
selected_directions = df['Direction'].tolist()

# Calculate -log10(p-value) for plotting (this is the p-value, not the FDR)
df['minus_log10_pval'] = -np.log10(df['Selected_pval'])


sns.set_style("white")

# Define significance thresholds
pval_threshold = 0.05   # compared against the selected (unadjusted) p-value
lfc_threshold = 0.3     # minimum absolute log fold change

# Categorize genes based on thresholds
def categorize_gene(row):
    """Label one gene row as 'Enriched', 'Depleted' or 'Not Significant'.

    A gene is 'Enriched'/'Depleted' only if its Selected_pval is below
    pval_threshold AND its Selected_lfc is above lfc_threshold / below
    -lfc_threshold; everything else is 'Not Significant'. The thresholds are
    read from the module-level variables at call time.
    """
    if row['Selected_pval'] < pval_threshold:
        if row['Selected_lfc'] > lfc_threshold:
            return 'Enriched'
        elif row['Selected_lfc'] < -lfc_threshold:
            return 'Depleted'
    return 'Not Significant'

df['Category'] = df.apply(categorize_gene, axis=1)

# Create color palette
palette = {'Enriched': 'red', 'Depleted': 'blue', 'Not Significant': 'grey'}

# Create the plot
plt.figure(figsize=(10, 8))
sns.scatterplot(data=df, x='Selected_lfc', y='minus_log10_pval', hue='Category', palette=palette, edgecolor='none', alpha=0.7)

# Add threshold lines
plt.axhline(-np.log10(pval_threshold), linestyle='--', color='black', linewidth=1)
plt.axvline(lfc_threshold, linestyle='--', color='black', linewidth=1)
plt.axvline(-lfc_threshold, linestyle='--', color='black', linewidth=1)

# Annotate top genes: the 50 categorized genes with the smallest selected p-value
top_genes = df[df['Category'] != 'Not Significant'].sort_values('Selected_pval').head(50)
for _, row in top_genes.iterrows():
    plt.text(row['Selected_lfc'], row['minus_log10_pval'], row['id'], fontsize=8)

# Customize plot
plt.title('Volcano Plot of MAGeCK Results')
plt.xlabel('Log2 Fold Change')
plt.ylabel('-Log10(pval)')
plt.legend(title='Category')
plt.tight_layout()
plt.show()


In [None]:
#important code block here to analyze data separately and together

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
import os
from adjustText import adjust_text

# Setup
# Experimental factors and the directory searched for gene summaries
conditions = ['AAV6', 'Ark315', 'Ark312']
donors = ['Donor20', 'Donor21']
base_path = '/data/JE_misc/GW_screen'
# Thresholds used when categorizing genes (p-value and absolute LFC)
pval_threshold = 0.05
lfc_threshold = 0.5
# Colors per gene category, used as the seaborn `palette` in the volcano plots
palette = {
    'Enriched': '#D73027',   # muted deep red
    'Depleted': '#4575B4',   # muted deep blue
    'Not Significant': '#BDBDBD'  # soft gray
}

def process_file(file_path):
    """Load one MAGeCK gene summary and derive the per-gene volcano-plot columns.

    For each gene the 'neg|*' statistics are used when 'pos|lfc' is negative
    and the 'pos|*' statistics otherwise. Genes are then categorized with the
    module-level ``pval_threshold`` and ``lfc_threshold`` (read at call time).

    Args:
        file_path: Path to a tab-separated gene summary containing the columns
            'id', 'neg|p-value', 'neg|lfc', 'pos|p-value' and 'pos|lfc'.

    Returns:
        DataFrame with the columns 'id', 'Selected_pval', 'Selected_lfc',
        'minus_log10_pval' and 'Category', where Category is one of
        'Enriched', 'Depleted' or 'Not Significant'.
    """
    df = pd.read_csv(file_path, sep='\t')

    # Column-wise selection replaces the former Python-level loop over rows;
    # a NaN pos|lfc compares False and therefore takes the 'pos' branch, as before.
    use_negative = df['pos|lfc'] < 0
    df['Selected_pval'] = np.where(use_negative, df['neg|p-value'], df['pos|p-value'])
    df['Selected_lfc'] = np.where(use_negative, df['neg|lfc'], df['pos|lfc'])
    df['minus_log10_pval'] = -np.log10(df['Selected_pval'])

    # A gene needs both a small p-value and a large enough |LFC| to be categorized
    significant = (df['Selected_pval'] < pval_threshold).to_numpy()
    enriched = significant & (df['Selected_lfc'] > lfc_threshold).to_numpy()
    depleted = significant & (df['Selected_lfc'] < -lfc_threshold).to_numpy()
    df['Category'] = np.select(
        [enriched, depleted],
        ['Enriched', 'Depleted'],
        default='Not Significant',
    )
    return df[['id', 'Selected_pval', 'Selected_lfc', 'minus_log10_pval', 'Category']]

def label_top_genes(ax, df, top_n=10):
    """Write gene names next to the most significant points of a volcano plot.

    The ``top_n`` genes with the smallest 'Selected_pval' are taken separately
    from the genes with positive and with negative 'Selected_lfc'; each is
    labelled with its 'id' at (Selected_lfc, minus_log10_pval), and the labels
    are then passed to ``adjust_text`` to reduce overlaps.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with 'id', 'Selected_lfc', 'Selected_pval' and
            'minus_log10_pval' columns.
        top_n: Number of genes to label per direction.
    """
    positive = df['Selected_lfc'] > 0
    negative = df['Selected_lfc'] < 0
    hits = pd.concat(
        [df[mask].nsmallest(top_n, 'Selected_pval') for mask in (positive, negative)]
    )

    labels = [
        ax.text(x, y, gene, fontsize=6, ha='center', va='bottom')
        for x, y, gene in zip(hits['Selected_lfc'], hits['minus_log10_pval'], hits['id'])
    ]

    # Adjust text to prevent overlaps (no arrows)
    adjust_text(labels, ax=ax)

In [None]:
# Step 1: Process all available donor-condition combinations
# all_results maps (condition, donor) -> DataFrame returned by process_file
all_results = {}
for condition in conditions:
    for donor in donors:
        pattern = f"{base_path}/{donor}/{condition}/**/*gene_summary.txt"
        # The recursive pattern can match several summaries below one
        # donor/condition directory (e.g. per-replicate and pooled runs), and
        # glob returns them in arbitrary order. Sort so that the same file is
        # analysed on every run, and say which one was chosen.
        matches = sorted(glob(pattern, recursive=True))
        if matches:
            if len(matches) > 1:
                print(f"{len(matches)} gene summaries found for {donor} {condition}; "
                      f"using {matches[0]}")
            df = process_file(matches[0])  # take first match (in sorted order)
            all_results[(condition, donor)] = df



# Step 2: Plot volcano plots for each donor-condition
# Grid layout: one row per condition, one column per donor.
fig, axes = plt.subplots(len(conditions), len(donors), figsize=(14, 12))
for i, condition in enumerate(conditions):
    for j, donor in enumerate(donors):
        ax = axes[i, j]
        key = (condition, donor)
        if key in all_results:
            df = all_results[key]
            sns.scatterplot(data=df, x='Selected_lfc', y='minus_log10_pval',
                            hue='Category', palette=palette, ax=ax, legend=False, alpha=0.7)
            # Dashed lines at the p-value and +/- LFC thresholds
            ax.axhline(-np.log10(pval_threshold), linestyle='--', color='black')
            ax.axvline(lfc_threshold, linestyle='--', color='black')
            ax.axvline(-lfc_threshold, linestyle='--', color='black')
            label_top_genes(ax, df)
            ax.set_title(f'{condition} - {donor}')
        else:
            # Leave the panel empty but say why
            ax.set_title(f'{condition} - {donor}\n(No file found)')
        ax.set_xlabel('LFC')
        ax.set_ylabel('-log10(pval)')
        
plt.tight_layout()
# Written relative to the current working directory
plt.savefig("volcano_separate.svg", format="svg")
plt.show()

In [None]:
# Depleted gene ids per (condition, donor), only for combinations that were processed
depleted_genes = {
    (condition, donor): all_results[(condition, donor)]
    .loc[lambda frame: frame['Category'] == 'Depleted', 'id']
    .tolist()
    for condition in conditions
    for donor in donors
    if (condition, donor) in all_results
}

# Per condition: genes depleted in both donors versus in only one of them.
# A donor without results contributes an empty set.
comparison_results = {}
for condition in conditions:
    donor1_genes = set(depleted_genes.get((condition, donors[0]), []))
    donor2_genes = set(depleted_genes.get((condition, donors[1]), []))

    conserved = donor1_genes & donor2_genes
    unique_donor1 = donor1_genes - donor2_genes
    unique_donor2 = donor2_genes - donor1_genes

    comparison_results[condition] = dict(
        donor1_total=len(donor1_genes),
        donor2_total=len(donor2_genes),
        conserved=len(conserved),
        unique_donor1=len(unique_donor1),
        unique_donor2=len(unique_donor2),
        conserved_genes=list(conserved),
        unique_donor1_genes=list(unique_donor1),
        unique_donor2_genes=list(unique_donor2),
    )

# Build one summary row per condition
summary_data = []
for condition in conditions:
    if condition not in comparison_results:
        continue
    data = comparison_results[condition]
    # Size of the union of both donors' depleted sets
    total_unique = data['donor1_total'] + data['donor2_total'] - data['conserved']
    if total_unique > 0:
        # Share of the union that is depleted in both donors, in percent
        conservation_rate = data['conserved'] / total_unique * 100
    else:
        conservation_rate = 0

    row = {'Condition': condition}
    row[f'{donors[0]}_Total'] = data['donor1_total']
    row[f'{donors[1]}_Total'] = data['donor2_total']
    row['Conserved'] = data['conserved']
    row[f'Unique_to_{donors[0]}'] = data['unique_donor1']
    row[f'Unique_to_{donors[1]}'] = data['unique_donor2']
    row['Conservation_Rate_%'] = round(conservation_rate, 1)
    summary_data.append(row)

# Show the table, or a notice when there is nothing to show
if not summary_data:
    print("No data available for comparison")
else:
    summary_df = pd.DataFrame(summary_data)
    print("DEPLETED GENES COMPARISON SUMMARY:")
    print("="*60)
    print(summary_df.to_string(index=False))

# Write the conserved / donor-specific depleted gene lists to text files
# (one gene per line, sorted) in the current working directory.
print("\n" + "="*60)
print("EXPORTING GENE LISTS")
print("="*60)

for condition in conditions:
    if condition not in comparison_results:
        continue
    data = comparison_results[condition]

    # (genes, output file, word used in the log line) for each list to export
    exports = [
        (data['conserved_genes'], f"conserved_depleted_{condition}.txt", "conserved"),
        (data['unique_donor1_genes'], f"unique_depleted_{condition}_{donors[0]}.txt", "unique"),
        (data['unique_donor2_genes'], f"unique_depleted_{condition}_{donors[1]}.txt", "unique"),
    ]
    for genes, filename, label in exports:
        if not genes:
            # Empty lists produce no file
            continue
        with open(filename, 'w') as f:
            f.writelines(f"{gene}\n" for gene in sorted(genes))
        print(f"Exported {len(genes)} {label} genes to '{filename}'")

# Generate visual comparison
# One bar chart per condition: conserved vs donor-specific depleted gene counts.
fig, axes = plt.subplots(1, len(conditions), figsize=(12, 5))
# With a single condition plt.subplots returns one Axes, so wrap it for indexing
if len(conditions) == 1:
    axes = [axes]

for i, condition in enumerate(conditions):
    if condition in comparison_results:
        data = comparison_results[condition]
        
        categories = ['Conserved', f'Unique to\n{donors[0]}', f'Unique to\n{donors[1]}']
        values = [data['conserved'], data['unique_donor1'], data['unique_donor2']]
        colors = ['#2E8B57', '#FF6B6B', '#4ECDC4']
        
        bars = axes[i].bar(categories, values, color=colors, alpha=0.7)
        axes[i].set_title(f'{condition}\nDepleted Genes')
        axes[i].set_ylabel('Number of Genes')
        
        # Add value labels on bars (zero-height bars are left unlabelled)
        for bar, value in zip(bars, values):
            if value > 0:
                axes[i].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
                           str(value), ha='center', va='bottom')

plt.tight_layout()
plt.show()

print("\nAnalysis complete!")

In [None]:
from scipy import stats

# thresholds
# Same values as in the setup cell; re-bound here so this cell sets them explicitly.
pval_threshold = 0.05
lfc_threshold = 0.5

# Same three conditions as before, listed in a different order
conditions = ['AAV6','Ark312','Ark315']

def fisher_combine_pvalues(p1, p2):
    """
    Combine two p-values using Fisher's method

    The statistic -2 * (ln(p1) + ln(p2)) is compared against a chi-square
    distribution with 4 degrees of freedom (2 per p-value).

    Args:
        p1, p2: The two p-values to combine.

    Returns:
        The combined p-value; NaN if either input is NaN, and the fixed
        floor 1e-300 if either input is exactly 0.
    """
    # Handle edge cases
    if pd.isna(p1) or pd.isna(p2):
        return np.nan
    if p1 == 0 or p2 == 0:
        return 1e-300  # Very small value to avoid log(0)

    # Fisher's method: -2 * sum(ln(p_i)) follows chi-square distribution with 2k df
    chi_squared = -2 * (np.log(p1) + np.log(p2))
    # With 2 p-values, df = 2*2 = 4.
    # Use the survival function rather than 1 - cdf: the subtraction loses all
    # precision once the tail probability drops below ~1e-16 and returned 0.
    combined_pval = stats.chi2.sf(chi_squared, df=4)
    return combined_pval

# Step 3: Donor-averaged results (mean LFC, Fisher-combined p-value) per condition
# NOTE(review): each donor's Selected_pval comes from the 'neg' or 'pos' test
# depending on that donor's LFC sign, so the two p-values combined here may
# belong to opposite directions — confirm this is intended.
averaged_results = {}
for condition in conditions:
    df1 = all_results.get((condition, 'Donor20'))
    df2 = all_results.get((condition, 'Donor21'))
    if df1 is None or df2 is None:
        # Both donors are needed to average a condition
        continue

    # Genes present in both donors; per-donor columns get the suffixes _1 / _2
    merged = df1.merge(df2, on='id', suffixes=('_1', '_2'))

    lfc_columns = ['Selected_lfc_1', 'Selected_lfc_2']
    merged['Selected_lfc'] = merged[lfc_columns].mean(axis=1)

    # Combine p-values using Fisher's method
    merged['Selected_pval'] = merged.apply(
        lambda row: fisher_combine_pvalues(row['Selected_pval_1'], row['Selected_pval_2']),
        axis=1
    )

    # Calculate -log10 p-value for plotting; the clip keeps the value finite
    floored_pval = merged['Selected_pval'].clip(lower=1e-300)
    merged['minus_log10_pval'] = -np.log10(floored_pval)

    def categorize(row):
        """Label a merged gene row using the combined p-value and mean LFC."""
        if not row['Selected_pval'] < pval_threshold:
            return 'Not Significant'
        if row['Selected_lfc'] > lfc_threshold:
            return 'Enriched'
        if row['Selected_lfc'] < -lfc_threshold:
            return 'Depleted'
        return 'Not Significant'

    merged['Category'] = merged.apply(categorize, axis=1)
    averaged_results[condition] = merged

# Save the averaged datasets
# The directory is relative to the current working directory.
output_dir = "averaged_results"
os.makedirs(output_dir, exist_ok=True)

for condition, df in averaged_results.items():
    # Save full dataset
    output_file = os.path.join(output_dir, f"{condition}_averaged_results.csv")
    df.to_csv(output_file, index=False)
    print(f"Saved {condition} averaged results to {output_file}")
    
    # Print summary statistics
    print(f"\n{condition} Summary:")
    print(f"Total genes: {len(df)}")
    print(f"Enriched: {sum(df['Category'] == 'Enriched')}")
    print(f"Depleted: {sum(df['Category'] == 'Depleted')}")
    print(f"Not Significant: {sum(df['Category'] == 'Not Significant')}")

# Step 4: Plot averaged volcano plots
if averaged_results:
    # One panel per averaged condition, side by side
    fig, axes = plt.subplots(1, len(averaged_results), figsize=(6 * len(averaged_results), 6))
    # With a single panel plt.subplots returns one Axes, so wrap it for zip()
    if len(averaged_results) == 1:
        axes = [axes]
    
    for ax, (condition, df) in zip(axes, averaged_results.items()):
        sns.scatterplot(data=df, x='Selected_lfc', y='minus_log10_pval',
                        hue='Category', palette=palette, ax=ax, legend=False, alpha=0.7)
        # Dashed lines at the p-value and +/- LFC thresholds
        ax.axhline(-np.log10(pval_threshold), linestyle='--', color='black')
        ax.axvline(lfc_threshold, linestyle='--', color='black')
        ax.axvline(-lfc_threshold, linestyle='--', color='black')
        label_top_genes(ax, df)
        ax.set_title(f'{condition} (Averaged - Fisher\'s Method)')
        ax.set_xlabel('LFC')
        ax.set_ylabel('-log10(pval)')
    
    plt.tight_layout()
    # Written relative to the current working directory
    plt.savefig("volcano_averaged_fisher.svg", format="svg")
    plt.show()

# Optional: Save a combined summary table
# The rows are collected into summary_data, but writing them to disk is
# currently commented out.
if averaged_results:
    summary_data = []
    for condition, df in averaged_results.items():
        summary_data.append({
            'Condition': condition,
            'Total_genes': len(df),
            'Enriched': sum(df['Category'] == 'Enriched'),
            'Depleted': sum(df['Category'] == 'Depleted'),
            'Not_Significant': sum(df['Category'] == 'Not Significant'),
            'Mean_LFC': df['Selected_lfc'].mean(),
            'Median_pval': df['Selected_pval'].median()
        })
    
    #summary_df = pd.DataFrame(summary_data)
    #summary_df.to_csv(os.path.join(output_dir, "summary_statistics.csv"), index=False)
    #print(f"\nSaved summary statistics to {os.path.join(output_dir, 'summary_statistics.csv')}")

In [None]:
import gseapy as gp
def run_enrichr(gene_list, gene_set_library, outdir):
    """Run an Enrichr over-representation query through gseapy (organism 'Human').

    Args:
        gene_list: Genes to test; passed unchanged to ``gp.enrichr``.
        gene_set_library: Enrichr gene-set library name, e.g.
            ``'KEGG_2021_Human'`` or ``'GO_Biological_Process_2021'``.
        outdir: Output directory handed to gseapy.

    Returns:
        The ``results`` attribute of the object returned by ``gp.enrichr``.
    """
    enrichr_options = {
        'gene_list': gene_list,
        'gene_sets': gene_set_library,
        'organism': 'Human',
        'outdir': outdir,
        'cutoff': 0.05,
        'no_plot': True,
    }
    return gp.enrichr(**enrichr_options).results


In [None]:
# pick your GO BP library
go_bp_lib = 'GO_Biological_Process_2025'

# Columns kept in the combined enrichment table
enrich_columns = ['Condition', 'Category', 'Term', 'GeneRatio', 'Overlap', 'MinusLog10Q']

all_enrich = []
for cond in conditions:
    # A condition without averaged results is skipped rather than raising KeyError
    if cond not in averaged_results:
        continue
    df = averaged_results[cond]
    depleted = df.query(
        "Selected_pval < @pval_threshold and Selected_lfc < -@lfc_threshold"
    )['id'].tolist()
    if not depleted:
        continue

    # ORA on KEGG (human)
    kegg = run_enrichr(depleted, 'KEGG_2021_Human', outdir=f"enrichr/{cond}/kegg")
    # ORA on GO BP (human, 2025)
    go_bp = run_enrichr(depleted, go_bp_lib, outdir=f"enrichr/{cond}/gobp")

    # Keep the 10 terms with the smallest adjusted p-value from each library
    for kind, res in [('KEGG', kegg), ('GO_BP', go_bp)]:
        top = (
            res
            .sort_values('Adjusted P-value')
            .head(10)
            .assign(
                Condition=cond,
                Category=kind,
                # 'Overlap' is a "hits/set size" string; convert it to a fraction
                GeneRatio=lambda d: d['Overlap']
                                    .str
                                    .split('/')
                                    .apply(lambda x: float(x[0]) / float(x[1])),
                MinusLog10Q=lambda d: -np.log10(d['Adjusted P-value'])
            )
        )
        all_enrich.append(top[enrich_columns])

# Combine the per-condition tables. pd.concat raises ValueError on an empty
# list, so fall back to an empty table with the expected columns.
if all_enrich:
    enrich_df = pd.concat(all_enrich, axis=0)
else:
    enrich_df = pd.DataFrame(columns=enrich_columns)
    print("No condition had depleted genes passing the thresholds; enrichment table is empty.")

In [None]:
import textwrap

# Global seaborn look for the enrichment dot plots below:
# white background with grid lines, "paper" context with fonts scaled to 0.8.
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=0.8)

def wrap_term(s, width=60):
    """Return ``s`` re-flowed so that no line is longer than ``width`` characters.

    Lines are joined with newline characters; an empty string stays empty.
    """
    return textwrap.fill(s, width)

# Fixed left-to-right order of the conditions on the x axis
condition_dtype = pd.CategoricalDtype(categories=["AAV6", "Ark312", "Ark315"],
                                      ordered=True)

# Plotting copy of the enrichment table with line-wrapped term labels
plot_df = enrich_df.assign(Term_wrapped=enrich_df["Term"].apply(wrap_term))
plot_df["Condition"] = plot_df["Condition"].astype(condition_dtype)

def plot_category(df_cat, category, outfile):
    """Draw, save and show a dot plot of the terms of one gene-set category.

    Rows of ``df_cat`` whose ``Category`` equals ``category`` are plotted with
    ``Condition`` on x and ``Term_wrapped`` on y. Dot size follows
    ``GeneRatio`` and dot colour follows ``MinusLog10Q``. A colour bar and a
    hand-built size legend are placed in the right margin.

    Args:
        df_cat: Table with columns ``Category``, ``Condition``,
            ``Term_wrapped``, ``GeneRatio`` and ``MinusLog10Q``.
        category: Value of ``Category`` to plot (also used in the title).
        outfile: Path the figure is written to, in SVG format.
    """
    size_range = (40, 200)  # marker area for the smallest / largest GeneRatio

    sub = df_cat[df_cat["Category"] == category].copy()
    q_scores = sub["MinusLog10Q"]
    gr_low, gr_high = sub["GeneRatio"].min(), sub["GeneRatio"].max()

    # Figure height grows with the number of distinct term labels
    fig_height = 0.35 * sub["Term_wrapped"].nunique() + 1
    fig, ax = plt.subplots(figsize=(7, fig_height))

    sns.scatterplot(
        data=sub,
        x="Condition",
        y="Term_wrapped",
        size="GeneRatio",
        hue="MinusLog10Q",
        sizes=size_range,
        palette="viridis",
        edgecolor="black",
        linewidth=0.2,
        alpha=0.9,
        ax=ax,
        legend=False,
    )
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(f"Top enriched pathways (depleted genes) – {category}",
                 fontsize=12, pad=6)
    ax.tick_params(axis="x", rotation=45, labelsize=8)
    ax.tick_params(axis="y", labelsize=8)
    ax.set_xlim(-0.4, 2.4)

    # Keep the right 20% of the figure free for the two legends
    plt.tight_layout(rect=[0, 0, 0.8, 1])

    # Colour bar; axes rectangle is [x0, y0, width, height] in figure coordinates
    cax = fig.add_axes([0.82, 0.25, 0.02, 0.55])
    sm = plt.cm.ScalarMappable(
        cmap="viridis",
        norm=plt.Normalize(q_scores.min(), q_scores.max()),
    )
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cax, orientation="vertical")
    cbar.set_label("-log10(Q-value)", fontsize=8)
    cbar.ax.tick_params(labelsize=7)

    # Size legend drawn on its own invisible axes below the colour bar
    lax = fig.add_axes([0.8, 0.10, 0.16, 0.12])
    lax.axis("off")
    lax.set_title("GeneRatio", fontsize=8, pad=0)

    # Four evenly spaced GeneRatio values, mapped linearly onto size_range
    gr_vals = np.linspace(gr_low, gr_high, 4)
    marker_sizes = np.interp(gr_vals, [gr_low, gr_high], size_range)
    for row_pos, (ratio, marker_size) in enumerate(zip(gr_vals, marker_sizes)):
        lax.scatter(0, row_pos, s=marker_size, color="gray", alpha=0.6)
        lax.text(0.4, row_pos, f"{ratio:.3f}", va="center", fontsize=7)
    lax.set_ylim(-0.5, len(gr_vals) - 0.5)
    lax.set_xlim(-0.2, 1)

    plt.savefig(outfile, format="svg", bbox_inches="tight")
    print(f"Saved {category} plot to {outfile}")
    plt.show()

# One dot plot per gene-set library, each written to its own SVG
for category_name, svg_name in (("KEGG", "KEGG_enrichment.svg"),
                                ("GO_BP", "GO_BP_enrichment.svg")):
    plot_category(plot_df, category_name, svg_name)


In [None]:
# Collect the genes labelled 'Depleted' in each condition's averaged results
averaged_depleted = {}
all_averaged_depleted = set()

section_rule = "=" * 60
print("SIGNIFICANTLY DEPLETED GENES FROM AVERAGED RESULTS:")
print(section_rule)

for condition in conditions:
    if condition not in averaged_results:
        continue
    df = averaged_results[condition]
    gene_list = df.loc[df['Category'] == 'Depleted', 'id'].tolist()
    averaged_depleted[condition] = gene_list
    all_averaged_depleted.update(gene_list)
    print(f"{condition}: {len(gene_list)} significantly depleted genes")

# Write one sorted gene list per condition, one gene per line
print("\n" + section_rule)
print("EXPORTING GENE LISTS")
print(section_rule)

for condition in conditions:
    if condition not in averaged_depleted:
        continue
    filename = f"depleted_genes_{condition}_averaged.txt"
    with open(filename, 'w') as f:
        f.writelines(f"{gene}\n" for gene in sorted(averaged_depleted[condition]))
    print(f"Exported {len(averaged_depleted[condition])} genes to '{filename}'")

# Pairwise overlap of the depleted gene sets
print("\n" + section_rule)
print("CONDITION SIMILARITY ANALYSIS")
print(section_rule)

from itertools import combinations
import pandas as pd

similarity_results = {}
for cond1, cond2 in combinations(conditions, 2):
    if not (cond1 in averaged_depleted and cond2 in averaged_depleted):
        continue

    genes1 = set(averaged_depleted[cond1])
    genes2 = set(averaged_depleted[cond2])

    conserved = genes1 & genes2
    unique_to_cond1 = genes1 - genes2
    unique_to_cond2 = genes2 - genes1
    total_unique = len(genes1 | genes2)
    smaller_set_size = min(len(genes1), len(genes2))

    # Jaccard = |A∩B| / |A∪B|; overlap coefficient = |A∩B| / min(|A|, |B|).
    # Both fall back to 0 when the denominator is empty.
    jaccard = len(conserved) / total_unique if total_unique > 0 else 0
    overlap_coeff = len(conserved) / smaller_set_size if smaller_set_size > 0 else 0

    similarity_results[(cond1, cond2)] = {
        'cond1_total': len(genes1),
        'cond2_total': len(genes2),
        'conserved': len(conserved),
        'unique_cond1': len(unique_to_cond1),
        'unique_cond2': len(unique_to_cond2),
        'jaccard': jaccard,
        'overlap_coeff': overlap_coeff,
        'conserved_genes': list(conserved),
        'unique_cond1_genes': list(unique_to_cond1),
        'unique_cond2_genes': list(unique_to_cond2)
    }

    print(f"\n{cond1} vs {cond2}:")
    print(f"  {cond1}: {len(genes1)} genes")
    print(f"  {cond2}: {len(genes2)} genes")
    print(f"  Conserved: {len(conserved)} genes")
    print(f"  Unique to {cond1}: {len(unique_to_cond1)} genes")
    print(f"  Unique to {cond2}: {len(unique_to_cond2)} genes")
    print(f"  Overlap coefficient: {overlap_coeff:.3f}")

# Symmetric condition-by-condition matrix of overlap coefficients
conditions_with_data = [c for c in conditions if c in averaged_depleted]
n_conditions = len(conditions_with_data)

overlap_matrix = pd.DataFrame(index=conditions_with_data, columns=conditions_with_data, dtype=float)

# Each condition overlaps itself completely
for condition in conditions_with_data:
    overlap_matrix.at[condition, condition] = 1.0

# Mirror every pairwise value across the diagonal
for (cond1, cond2), results in similarity_results.items():
    pair_overlap = results['overlap_coeff']
    overlap_matrix.at[cond1, cond2] = pair_overlap
    overlap_matrix.at[cond2, cond1] = pair_overlap

print("\n" + section_rule)
print("SIMILARITY MATRICES")
print(section_rule)

print("\nOverlap Coefficient Matrix:")
print(overlap_matrix.round(3))

# Heatmap of the overlap coefficients
plt.figure(figsize=(8, 6))
sns.heatmap(
    overlap_matrix.astype(float),
    annot=True,
    cmap='Greens',
    cbar_kws={'label': 'Overlap Coefficient'},
)
plt.title('Overlap Coefficient\nBetween Conditions')
plt.tight_layout()
plt.show()

print("\nAnalysis complete!")

In [None]:
# Variant of label_top_genes that additionally highlights selected genes
# (CD7 and KIAA0319L by default) when they are present.
def label_top_genes_highlight_cd7(ax, df, top_n=10,
                                  highlight_genes=('CD7', 'KIAA0319L')):
    """Label the top hits on a volcano plot and emphasise selected genes.

    The ``top_n`` genes with the smallest ``Selected_pval`` are labelled on
    each side of zero LFC. Every gene in ``highlight_genes`` that occurs in
    ``df['id']`` gets a larger black marker and a bold label. The top-hit
    labels are then de-overlapped once with ``adjust_text``; the highlight
    labels are not passed to ``adjust_text`` and stay at their data point.

    Args:
        ax: Matplotlib axes holding the volcano plot.
        df: Table with columns ``id``, ``Selected_lfc``, ``Selected_pval``
            and ``minus_log10_pval``.
        top_n: Number of genes labelled per direction.
        highlight_genes: Gene ids to emphasise; ids missing from ``df`` are
            silently skipped.
    """
    top_hits = pd.concat([
        df[df['Selected_lfc'] > 0].nsmallest(top_n, 'Selected_pval'),
        df[df['Selected_lfc'] < 0].nsmallest(top_n, 'Selected_pval'),
    ])

    texts = [
        ax.text(row['Selected_lfc'], row['minus_log10_pval'], row['id'],
                fontsize=6, ha='center', va='bottom')
        for _, row in top_hits.iterrows()
    ]

    # Draw all highlight markers before adjusting, so a single adjust_text
    # call sees the final state of the axes.
    for gene in highlight_genes:
        matches = df[df['id'] == gene]
        if matches.empty:
            continue
        hit = matches.iloc[0]  # first row if the id occurs more than once
        ax.scatter(hit['Selected_lfc'], hit['minus_log10_pval'],
                   color='black', s=60, zorder=5)
        ax.text(hit['Selected_lfc'], hit['minus_log10_pval'], gene,
                fontsize=10, weight='bold', ha='center', va='bottom')

    # Nothing to adjust when no gene was labelled (e.g. empty df)
    if texts:
        adjust_text(texts, ax=ax)
    
# Re-run Step 4, this time labelling with label_top_genes_highlight_cd7
if averaged_results:
    n_panels = len(averaged_results)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6))
    # With a single panel plt.subplots returns a bare Axes; wrap it so zip() works
    if n_panels == 1:
        axes = [axes]

    for ax, (condition, df) in zip(axes, averaged_results.items()):
        sns.scatterplot(
            data=df,
            x='Selected_lfc',
            y='minus_log10_pval',
            hue='Category',
            palette=palette,
            alpha=0.7,
            legend=False,
            ax=ax,
        )
        # Dashed guides at the p-value cut-off and at +/- the LFC cut-off
        ax.axhline(-np.log10(pval_threshold), linestyle='--', color='black')
        for lfc_cut in (lfc_threshold, -lfc_threshold):
            ax.axvline(lfc_cut, linestyle='--', color='black')
        label_top_genes_highlight_cd7(ax, df)
        ax.set(
            title=f'{condition} (Averaged)',
            xlabel='LFC',
            ylabel='-log10(pval)',
        )

    plt.tight_layout()
    plt.savefig("volcano_averaged_highlight_cd7.svg", format="svg")
    plt.show()


In [None]:
# Further analyses below build on the averaged gene-level results

In [None]:
#filter based on cell surface expression
import mygene

# Read the CSPA table and keep the rows whose 'Organisme' column is 'Human'
cspa_data = pd.read_csv('CSPA.csv')  # Replace with your actual file path
is_human = cspa_data['Organisme'] == 'Human'
human_cspa = cspa_data.loc[is_human]

# Distinct values of the 'ID link' column, in order of first appearance
unique_protein_ids = pd.unique(human_cspa['ID link'])
print(f"Found {len(unique_protein_ids)} unique surface proteins in CSPA")

In [None]:
mg = mygene.MyGeneInfo()

# Look up a gene symbol for every CSPA protein id (queried with scope 'uniprot')
print("Converting UniProt IDs to gene symbols...")
gene_mapping = mg.querymany(
    unique_protein_ids,
    scopes='uniprot',
    fields='symbol',
    species='human',
    returnall=True,
)

# Keep only the hits that carry a non-empty 'symbol' field
surface_genes = [hit['symbol'] for hit in gene_mapping['out'] if hit.get('symbol')]

print(f"Successfully mapped {len(surface_genes)} proteins to gene symbols")

In [None]:
# De-duplicate the mapped symbols. Sorting gives the list the same order on
# every run; a bare list(set(...)) order can change between interpreter sessions.
surface_genes = sorted(set(surface_genes))
print(f"Final unique surface genes: {len(surface_genes)}")

In [None]:
# Restrict each condition's averaged results to genes in the surface gene list
filtered_averaged_results = {}

for condition, df in averaged_results.items():
    is_surface = df['id'].isin(surface_genes)
    filtered_df = df[is_surface]
    filtered_averaged_results[condition] = filtered_df

    n_before, n_after = len(df), len(filtered_df)
    print(f"\n{condition}:")
    print(f"Original size: {n_before}")
    print(f"Filtered size: {n_after}")
    print(f"Filtered out {n_before - n_after} genes")

    # Category breakdown of the genes that survived the filter
    print(f"Filtered {condition} Summary:")
    for call in ('Enriched', 'Depleted', 'Not Significant'):
        print(f"{call}: {sum(filtered_df['Category'] == call)}")

# Write one CSV per condition into a dedicated output folder
import os

filtered_output_dir = "filtered_surface_results"
os.makedirs(filtered_output_dir, exist_ok=True)

for condition, df in filtered_averaged_results.items():
    csv_name = f"{condition}_surface_filtered.csv"
    output_file = os.path.join(filtered_output_dir, csv_name)
    df.to_csv(output_file, index=False)
    print(f"Saved {condition} surface-filtered results to {output_file}")

# Volcano plots restricted to the surface-filtered genes, one panel per condition
if filtered_averaged_results:
    n_panels = len(filtered_averaged_results)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6))
    # With a single panel plt.subplots returns a bare Axes; wrap it so zip() works
    if n_panels == 1:
        axes = [axes]

    for ax, (condition, df) in zip(axes, filtered_averaged_results.items()):
        sns.scatterplot(
            data=df,
            x='Selected_lfc',
            y='minus_log10_pval',
            hue='Category',
            palette=palette,
            alpha=0.7,
            legend=False,
            ax=ax,
        )
        # Dashed guides at the p-value cut-off and at +/- the LFC cut-off
        ax.axhline(-np.log10(pval_threshold), linestyle='--', color='black')
        for lfc_cut in (lfc_threshold, -lfc_threshold):
            ax.axvline(lfc_cut, linestyle='--', color='black')
        label_top_genes(ax, df)
        ax.set(
            title=f'{condition} - Surface Proteins Only',
            xlabel='LFC',
            ylabel='-log10(pval)',
        )

    plt.tight_layout()
    plt.savefig("volcano_surface_proteins_only.svg", format="svg")
    plt.show()

In [None]:
# Bare expression: displays the condition -> filtered table dict as the cell output
filtered_averaged_results

In [None]:
# Reset matplotlib to its default style and use seaborn's "deep" palette
plt.style.use('default')
sns.set_palette("deep")

# Per-condition LFC and significance call, with condition-specific column names
aav6_data = filtered_averaged_results['AAV6'][['id', 'Selected_lfc', 'Category']].rename(
    columns={'Selected_lfc': 'AAV6_lfc', 'Category': 'AAV6_Category'}
)
ark312_data = filtered_averaged_results['Ark312'][['id', 'Selected_lfc', 'Category']].rename(
    columns={'Selected_lfc': 'Ark312_lfc', 'Category': 'Ark312_Category'}
)

# Inner join on gene id: only genes present in both conditions are kept
correlation_data = aav6_data.merge(ark312_data, on='id', how='inner')

# Create combined significance categories
def get_combined_category(row):
    """Combine the AAV6 and Ark312 significance calls of one gene.

    ``row`` must provide ``'AAV6_Category'`` and ``'Ark312_Category'``.
    Enrichment takes precedence over depletion: a gene enriched in at least
    one condition is reported as enriched even if it is depleted in the other.

    Returns one of ``'Enriched_Both'``, ``'Enriched_Single'``,
    ``'Depleted_Both'``, ``'Depleted_Single'`` or ``'Not_Significant'``.
    """
    calls = (row['AAV6_Category'], row['Ark312_Category'])
    for direction in ('Enriched', 'Depleted'):
        n_conditions_hit = calls.count(direction)
        if n_conditions_hit == 2:
            return f'{direction}_Both'
        if n_conditions_hit == 1:
            return f'{direction}_Single'
    return 'Not_Significant'

correlation_data['Combined_Category'] = correlation_data.apply(get_combined_category, axis=1)

# Rank correlation of the two LFC columns
spearman_r, spearman_p = spearmanr(correlation_data['AAV6_lfc'], correlation_data['Ark312_lfc'])

# Colour per combined category
color_map = {
    'Enriched_Both': '#8B0000',      # Dark red - enriched in both
    'Enriched_Single': '#d62728',    # Original red - enriched in one
    'Depleted_Both': '#0f4c8c',      # Slightly darker blue - depleted in both
    'Depleted_Single': '#1f77b4',    # Original blue - depleted in one
    'Not_Significant': '#C0C0C0'     # Light gray - not significant
}

# Human-readable legend entry per combined category
legend_labels = {
    'Not_Significant': 'Not Significant',
    'Depleted_Single': 'Depleted (one condition)',
    'Depleted_Both': 'Depleted (both conditions)',
    'Enriched_Single': 'Enriched (one condition)',
    'Enriched_Both': 'Enriched (both conditions)'
}

fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

# Draw background points first so the significant categories end up on top
plot_order = ['Not_Significant', 'Depleted_Single', 'Enriched_Single', 'Depleted_Both', 'Enriched_Both']
for category in plot_order:
    mask = correlation_data['Combined_Category'] == category
    if not mask.any():
        continue
    # Non-significant genes are drawn smaller, fainter and without an outline
    is_background = category == 'Not_Significant'
    ax.scatter(
        correlation_data.loc[mask, 'AAV6_lfc'],
        correlation_data.loc[mask, 'Ark312_lfc'],
        c=color_map[category],
        label=legend_labels[category],
        alpha=0.5 if is_background else 0.8,
        s=8 if is_background else 15,
        edgecolors='none' if is_background else 'white',
        linewidth=0 if is_background else 0.3,
    )

# Genes called Enriched or Depleted in Ark312 get a text label
ark312_significant = correlation_data[correlation_data['Ark312_Category'].isin(['Enriched', 'Depleted'])]

# Use adjustText for better label positioning
from adjustText import adjust_text
texts = []
for gene_id, x_lfc, y_lfc in zip(ark312_significant['id'],
                                 ark312_significant['AAV6_lfc'],
                                 ark312_significant['Ark312_lfc']):
    # Start each label slightly away from its point, on the side away from the origin
    label_x = x_lfc + (0.05 if x_lfc > 0 else -0.05)
    label_y = y_lfc + (0.05 if y_lfc > 0 else -0.05)
    texts.append(
        ax.annotate(gene_id,
                    xy=(x_lfc, y_lfc),
                    xytext=(label_x, label_y),
                    fontsize=6,
                    alpha=0.8,
                    ha='center', va='center')
    )

# Adjust text positions to avoid overlap - no arrows
adjust_text(texts, ax=ax,
            force_points=(0.1, 0.1), force_text=(0.1, 0.1))

# Symmetric axis limit: largest absolute LFC in either condition plus 10%
lfc_extremes = []
for lfc_column in ('AAV6_lfc', 'Ark312_lfc'):
    lfc_extremes += [correlation_data[lfc_column].min(), correlation_data[lfc_column].max()]
lim = max(abs(value) for value in lfc_extremes) * 1.1

# Identity diagonal plus faint zero lines
ax.plot([-lim, lim], [-lim, lim], 'k--', alpha=0.3, linewidth=1)
ax.axhline(0, color='k', linestyle='-', alpha=0.2, linewidth=0.5)
ax.axvline(0, color='k', linestyle='-', alpha=0.2, linewidth=0.5)

# Axis labels and title
ax.set_xlabel('AAV6 Log₂ Fold Change', fontsize=12, fontweight='bold')
ax.set_ylabel('Ark312 Log₂ Fold Change', fontsize=12, fontweight='bold')
ax.set_title(f'Surface Protein Expression Correlation\nAAV6 vs Ark312 (ρ = {spearman_r:.3f}, P = {spearman_p:.2e})',
             fontsize=8, fontweight='bold', pad=20)

# Square plot with identical limits on both axes
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_aspect('equal')

# Legend placed outside the plot area, with a thin gray frame
legend = ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fontsize=9, markerscale=1.2)
legend_frame = legend.get_frame()
legend_frame.set_facecolor('white')
legend_frame.set_alpha(0.95)
legend_frame.set_edgecolor('gray')
legend_frame.set_linewidth(0.5)

# Hide the top/right spines and thin the remaining ones
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(0.5)
ax.tick_params(axis='both', which='major', labelsize=10, width=0.5)

# Ensure grid is subtle
ax.grid(True, alpha=0.1, linewidth=0.5)

plt.tight_layout()
plt.savefig('AAV6_vs_Ark312_correlation.svg', format='svg', dpi=300, bbox_inches='tight')

plt.show()

# Text report
print("Correlation Analysis: AAV6 vs Ark312 (Surface Proteins)")
print('=' * 50)
print(f"Number of genes: {len(correlation_data)}")
print(f"Spearman correlation: ρ = {spearman_r:.4f}, P = {spearman_p:.2e}")
print("\nCombined category breakdown:")
for category, count in correlation_data['Combined_Category'].value_counts().items():
    print(f"  {category.replace('_', ' ')}: {count} genes")

print("\nGenes significant in Ark312:")
print('=' * 30)
for _, row in ark312_significant.iterrows():
    aav6_status = row['AAV6_Category']
    ark312_status = row['Ark312_Category']
    print(f"{row['id']:15} | AAV6: {aav6_status:15} | Ark312: {ark312_status:10} | LFC: ({row['AAV6_lfc']:+.2f}, {row['Ark312_lfc']:+.2f})")

In [None]:
# Reset matplotlib to its default style and use seaborn's "deep" palette
plt.style.use('default')
sns.set_palette("deep")

# Per-condition LFC and significance call, with condition-specific column names
aav6_data = filtered_averaged_results['AAV6'][['id', 'Selected_lfc', 'Category']].rename(
    columns={'Selected_lfc': 'AAV6_lfc', 'Category': 'AAV6_Category'}
)
ark315_data = filtered_averaged_results['Ark315'][['id', 'Selected_lfc', 'Category']].rename(
    columns={'Selected_lfc': 'Ark315_lfc', 'Category': 'Ark315_Category'}
)

# Inner join on gene id: only genes present in both conditions are kept
correlation_data = aav6_data.merge(ark315_data, on='id', how='inner')

# Create combined significance categories
def get_combined_category(row):
    """Combine the AAV6 and Ark315 significance calls of one gene.

    ``row`` must provide ``'AAV6_Category'`` and ``'Ark315_Category'``.
    Enrichment takes precedence over depletion: a gene enriched in at least
    one condition is reported as enriched even if it is depleted in the other.

    Returns one of ``'Enriched_Both'``, ``'Enriched_Single'``,
    ``'Depleted_Both'``, ``'Depleted_Single'`` or ``'Not_Significant'``.
    """
    calls = (row['AAV6_Category'], row['Ark315_Category'])
    for direction in ('Enriched', 'Depleted'):
        n_conditions_hit = calls.count(direction)
        if n_conditions_hit == 2:
            return f'{direction}_Both'
        if n_conditions_hit == 1:
            return f'{direction}_Single'
    return 'Not_Significant'

correlation_data['Combined_Category'] = correlation_data.apply(get_combined_category, axis=1)

# Rank correlation of the two LFC columns
spearman_r, spearman_p = spearmanr(correlation_data['AAV6_lfc'], correlation_data['Ark315_lfc'])

# Colour per combined category
color_map = {
    'Enriched_Both': '#8B0000',      # Dark red - enriched in both
    'Enriched_Single': '#d62728',    # Original red - enriched in one
    'Depleted_Both': '#0f4c8c',      # Slightly darker blue - depleted in both
    'Depleted_Single': '#1f77b4',    # Original blue - depleted in one
    'Not_Significant': '#C0C0C0'     # Light gray - not significant
}

# Human-readable legend entry per combined category (loop-invariant, so built once)
legend_labels = {
    'Not_Significant': 'Not Significant',
    'Depleted_Single': 'Depleted (one condition)',
    'Depleted_Both': 'Depleted (both conditions)',
    'Enriched_Single': 'Enriched (one condition)',
    'Enriched_Both': 'Enriched (both conditions)'
}

# Create the figure with specific dimensions for publication
fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

# Draw background points first so the significant categories end up on top
plot_order = ['Not_Significant', 'Depleted_Single', 'Enriched_Single', 'Depleted_Both', 'Enriched_Both']
for category in plot_order:
    mask = correlation_data['Combined_Category'] == category
    if mask.any():
        # Non-significant genes are drawn smaller, fainter and without an outline
        alpha_val = 0.5 if category == 'Not_Significant' else 0.8
        size_val = 8 if category == 'Not_Significant' else 15
        edge_color = 'white' if category != 'Not_Significant' else 'none'
        edge_width = 0.3 if category != 'Not_Significant' else 0

        ax.scatter(correlation_data.loc[mask, 'AAV6_lfc'],
                   correlation_data.loc[mask, 'Ark315_lfc'],
                   c=color_map[category],
                   label=legend_labels[category],
                   alpha=alpha_val,
                   s=size_val,
                   edgecolors=edge_color,
                   linewidth=edge_width)

# Genes called Enriched or Depleted in Ark315 get a text label.
# Stored under an Ark315 name so the Ark312 table from the previous cell
# (ark312_significant) is not overwritten.
ark315_significant = correlation_data[correlation_data['Ark315_Category'].isin(['Enriched', 'Depleted'])]

# Use adjustText for better label positioning
from adjustText import adjust_text
texts = []
for _, row in ark315_significant.iterrows():
    # Start each label slightly away from its point, on the side away from the origin
    offset_x = 0.05 if row['AAV6_lfc'] > 0 else -0.05
    offset_y = 0.05 if row['Ark315_lfc'] > 0 else -0.05

    text = ax.annotate(row['id'],
                       xy=(row['AAV6_lfc'], row['Ark315_lfc']),
                       xytext=(row['AAV6_lfc'] + offset_x, row['Ark315_lfc'] + offset_y),
                       fontsize=6,
                       alpha=0.8,
                       ha='center', va='center')
    texts.append(text)

# Adjust text positions to avoid overlap - no arrows
adjust_text(texts, ax=ax,
            force_points=(0.1, 0.1), force_text=(0.1, 0.1))

# Symmetric axis limit: largest absolute LFC in either condition plus 10%
lim = max(abs(correlation_data['AAV6_lfc'].min()),
          abs(correlation_data['AAV6_lfc'].max()),
          abs(correlation_data['Ark315_lfc'].min()),
          abs(correlation_data['Ark315_lfc'].max())) * 1.1

# Identity diagonal
ax.plot([-lim, lim], [-lim, lim], 'k--', alpha=0.3, linewidth=1)

# Add horizontal and vertical lines at 0
ax.axhline(0, color='k', linestyle='-', alpha=0.2, linewidth=0.5)
ax.axvline(0, color='k', linestyle='-', alpha=0.2, linewidth=0.5)

# Formatting for publication quality
ax.set_xlabel('AAV6 Log₂ Fold Change', fontsize=12, fontweight='bold')
ax.set_ylabel('Ark315 Log₂ Fold Change', fontsize=12, fontweight='bold')
ax.set_title(f'Surface Protein Expression Correlation\nAAV6 vs Ark315 (ρ = {spearman_r:.3f}, P = {spearman_p:.2e})',
             fontsize=8, fontweight='bold', pad=20)

# Set equal aspect ratio and limits
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_aspect('equal')

# Customize legend - place outside the plot area
legend = ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fontsize=9, markerscale=1.2)
legend.get_frame().set_facecolor('white')
legend.get_frame().set_alpha(0.95)
legend.get_frame().set_edgecolor('gray')
legend.get_frame().set_linewidth(0.5)

# Clean up spines and ticks
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(0.5)
ax.spines['bottom'].set_linewidth(0.5)
ax.tick_params(axis='both', which='major', labelsize=10, width=0.5)

# Ensure grid is subtle
ax.grid(True, alpha=0.1, linewidth=0.5)

plt.tight_layout()
plt.savefig('AAV6_vs_Ark315_correlation.svg', format='svg', dpi=300, bbox_inches='tight')

plt.show()

# Print detailed statistics
print("Correlation Analysis: AAV6 vs Ark315 (Surface Proteins)")
print('=' * 50)
print(f"Number of genes: {len(correlation_data)}")
print(f"Spearman correlation: ρ = {spearman_r:.4f}, P = {spearman_p:.2e}")
print("\nCombined category breakdown:")
for category, count in correlation_data['Combined_Category'].value_counts().items():
    print(f"  {category.replace('_', ' ')}: {count} genes")

# Header and status variable now refer to Ark315; the original read an
# undefined name here and raised NameError on the first significant gene.
print("\nGenes significant in Ark315:")
print('=' * 30)
for _, row in ark315_significant.iterrows():
    aav6_status = row['AAV6_Category']
    ark315_status = row['Ark315_Category']
    print(f"{row['id']:15} | AAV6: {aav6_status:15} | Ark315: {ark315_status:10} | LFC: ({row['AAV6_lfc']:+.2f}, {row['Ark315_lfc']:+.2f})")
    

In [None]:
#now do pathway analysis

In [None]:
# Pairwise LFC scatter plots; genes far from the cloud's centre are labelled.
# Genes present in only one of the two conditions are kept in the table.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

comparisons = [
    ('AAV6', 'Ark312'),
    ('AAV6', 'Ark315'),
    ('Ark315', 'Ark312'),
]

for ax, (cond1, cond2) in zip(axes, comparisons):
    col_x, col_y = f'{cond1}_LFC', f'{cond2}_LFC'

    df1 = averaged_results[cond1].set_index("id")["Selected_lfc"]
    df2 = averaged_results[cond2].set_index("id")["Selected_lfc"]

    # Outer alignment on gene id; drop genes with no value in either condition
    df = pd.concat([df1, df2], axis=1, keys=[col_x, col_y]).dropna(how='all')

    # Euclidean distance of each gene from the (mean x, mean y) point.
    # It is NaN when either LFC is missing, so such genes are never flagged.
    offset_x = df[col_x] - df[col_x].mean()
    offset_y = df[col_y] - df[col_y].mean()
    df['distance'] = np.sqrt(offset_x**2 + offset_y**2)

    # Outlier = more than 4 standard deviations above the mean distance
    threshold = df['distance'].mean() + 4 * df['distance'].std()
    df['outlier'] = df['distance'] > threshold

    # All genes in gray
    ax.scatter(df[col_x], df[col_y], color='gray', alpha=0.5, s=10)

    # Outliers with both values present, in black and labelled by gene id
    has_both = df[col_x].notna() & df[col_y].notna()
    outliers = df[df["outlier"] & has_both]
    ax.scatter(outliers[col_x], outliers[col_y], color='black', s=20)
    for gene, x_val, y_val in zip(outliers.index, outliers[col_x], outliers[col_y]):
        ax.text(x_val, y_val, gene, fontsize=6, ha='center', va='bottom')

    # Identity line segment from (-1, -1) to (1, 1)
    ax.plot([-1, 1], [-1, 1], linestyle='--', color='black', linewidth=1)
    ax.set_title(f'{cond1} vs {cond2}')
    ax.set_xlabel(f'{cond1} LFC')
    ax.set_ylabel(f'{cond2} LFC')

plt.tight_layout()
#plt.savefig("LFC_pairwise_comparisons.svg", format="svg")
plt.show()


In [None]:
# Pairwise LFC scatter plots; genes whose LFC difference between the two
# conditions is extreme (in either direction) are labelled.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Condition names available in averaged_results (kept for manual inspection)
available_conditions = list(averaged_results)

comparisons = [
    ('Ark312', 'Ark315'),
    ('Ark312', 'AAV6'),
    ('AAV6', 'Ark315'),
]

for ax, (cond1, cond2) in zip(axes, comparisons):
    col_x, col_y = f'{cond1}_LFC', f'{cond2}_LFC'

    df1 = averaged_results[cond1].set_index("id")["Selected_lfc"]
    df2 = averaged_results[cond2].set_index("id")["Selected_lfc"]

    # Outer alignment on gene id; drop genes with no value in either condition
    df = pd.concat([df1, df2], axis=1, keys=[col_x, col_y]).dropna(how='all')

    # Difference in LFC; NaN when either condition is missing
    df['delta'] = df[col_x] - df[col_y]
    delta_valid = df.dropna(subset=['delta'])
    mean = delta_valid['delta'].mean()
    sd = delta_valid['delta'].std()

    # Flag deltas more than 4 SD from the mean; NaN deltas compare False
    df['outlier'] = (df['delta'] - mean).abs() > 4 * sd

    # All genes in gray
    ax.scatter(df[col_x], df[col_y], color='gray', alpha=0.5, s=10)

    # Outliers with both values present, in black and labelled by gene id
    has_both = df[col_x].notna() & df[col_y].notna()
    outliers = df[df['outlier'] & has_both]
    ax.scatter(outliers[col_x], outliers[col_y], color='black', s=20)
    for gene, x_val, y_val in zip(outliers.index, outliers[col_x], outliers[col_y]):
        ax.text(x_val, y_val, gene, fontsize=6, ha='center', va='bottom')

    # Identity line segment from (-1, -1) to (1, 1)
    ax.plot([-1, 1], [-1, 1], linestyle='--', color='black', linewidth=1)
    ax.set_title(f'{cond1} vs {cond2}')
    ax.set_xlabel(f'{cond1} LFC')
    ax.set_ylabel(f'{cond2} LFC')

plt.tight_layout()
plt.savefig("LFC_outliers.svg", format="svg")
plt.show()

In [None]:
#guide level analysis

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
import os

# Setup: constants for the guide-level (sgRNA) analysis
conditions = ['AAV6', 'Ark315', 'Ark312']  # condition names
donors = ['Donor20', 'Donor21']            # donor names
base_path = '/data/JE_misc/GW_screen'      # root folder of the screen data
pval_threshold = 0.05                      # p-value cut-off
lfc_threshold = 0.5                        # log fold-change cut-off
# Plot colour for each significance category
palette = {
    'Enriched': '#D73027',   # muted deep red
    'Depleted': '#4575B4',   # muted deep blue
    'Not Significant': '#BDBDBD'  # soft gray
}

def process_sgrna_file(file_path):
    """
    Load one MAGeCK sgRNA summary table and derive the per-guide plotting columns.

    For every sgRNA the directional p-value matching the sign of its LFC is
    selected: 'p.high' when LFC > 0 (enrichment), 'p.low' otherwise (depletion;
    this includes LFC == 0). 'p.twosided' is not used.

    Parameters:
    file_path: path to a tab-separated table containing an LFC column
        ('LFC', 'lfc' or 'logFC') and the 'p.low' / 'p.high' columns. The sgRNA
        ID is taken from 'sgrna', 'sgRNA', 'guide' or 'Guide', falling back to
        the first column of the table.

    Returns:
    DataFrame with 'id', 'Selected_pval', 'Selected_lfc', 'Direction', optional
    'Gene' and 'FDR', 'minus_log10_pval' and 'Category' ('Enriched', 'Depleted'
    or 'Not Significant', decided with the module-level pval_threshold and
    lfc_threshold). Rows without a p-value or LFC are dropped.
    Returns None, after printing the error and the available columns, when the
    table cannot be processed. Errors raised while reading the file itself
    propagate to the caller.
    """
    df = pd.read_csv(file_path, sep='\t')

    try:
        # Get LFC column (first known spelling that is present)
        lfc_col = next((c for c in ('LFC', 'lfc', 'logFC') if c in df.columns), None)
        if lfc_col is None:
            raise ValueError("No LFC column found")

        # Get sgRNA ID column; use the first column when no known name matches
        id_col = next(
            (c for c in ('sgrna', 'sgRNA', 'guide', 'Guide') if c in df.columns),
            df.columns[0],
        )

        # Handle MAGeCK's directional p-values in one vectorized pass instead of
        # a Python-level loop over every row (these tables are genome-wide):
        # positive LFC -> p.high / 'Enrichment', anything else -> p.low / 'Depletion'.
        lfc = df[lfc_col]
        is_enriched = lfc > 0

        processed_df = pd.DataFrame({
            'id': df[id_col],
            'Selected_pval': np.where(is_enriched, df['p.high'], df['p.low']),
            'Selected_lfc': lfc,
            'Direction': np.where(is_enriched, 'Enrichment', 'Depletion'),
        })

        # Add gene information if available
        if 'Gene' in df.columns:
            processed_df['Gene'] = df['Gene']
        elif 'gene' in df.columns:
            processed_df['Gene'] = df['gene']

        # Add FDR if available
        if 'FDR' in df.columns:
            processed_df['FDR'] = df['FDR']

        # Remove any rows with NaN values
        processed_df = processed_df.dropna(subset=['Selected_pval', 'Selected_lfc'])

        # Calculate -log10 p-value (clip so that p == 0 does not become inf)
        processed_df['minus_log10_pval'] = -np.log10(processed_df['Selected_pval'].clip(lower=1e-300))

        # Categorize sgRNAs: significant p-value AND |LFC| beyond the threshold
        significant = processed_df['Selected_pval'] < pval_threshold
        selected_lfc = processed_df['Selected_lfc']
        processed_df['Category'] = np.select(
            [significant & (selected_lfc > lfc_threshold),
             significant & (selected_lfc < -lfc_threshold)],
            ['Enriched', 'Depleted'],
            default='Not Significant',
        )

        return processed_df

    except Exception as e:
        # Deliberate best-effort: callers treat None as "could not process this file"
        print(f"Error processing {file_path}: {str(e)}")
        print(f"Available columns: {df.columns.tolist()}")
        return None

from adjustText import adjust_text

def label_top_sgrnas(ax, df, top_n=10):
    """
    Annotate the most significant sgRNAs on a volcano plot.

    The top_n guides with the smallest 'Selected_pval' are picked separately
    among guides with positive and with negative 'Selected_lfc'. Each one is
    labelled at (Selected_lfc, minus_log10_pval) with "<Gene>_<id>" when a
    non-missing 'Gene' value exists, otherwise with the sgRNA id alone.
    Labels are then repositioned with adjust_text to reduce overlaps.
    """
    gene_available = 'Gene' in df.columns
    lfc_values = df['Selected_lfc']

    # Enriched hits first, then depleted hits
    hits = pd.concat([
        df[side].nsmallest(top_n, 'Selected_pval')
        for side in (lfc_values > 0, lfc_values < 0)
    ])

    def _label_for(hit):
        # Gene_sgRNA format when the gene name is known, bare sgRNA ID otherwise
        if gene_available and pd.notna(hit.get('Gene')):
            return f"{hit['Gene']}_{hit['id']}"
        return hit['id']

    annotations = [
        ax.text(hit['Selected_lfc'], hit['minus_log10_pval'], _label_for(hit),
                fontsize=6, ha='center', va='bottom')
        for _, hit in hits.iterrows()
    ]

    # Nothing to de-overlap when no guide qualified
    if not annotations:
        return
    adjust_text(annotations, ax=ax)

# Step 1: Process all available donor-condition combinations for sgRNA data
# all_results maps (condition, donor) -> processed sgRNA DataFrame
all_results = {}
for condition in conditions:
    for donor in donors:
        pattern = f"{base_path}/{donor}/{condition}/**/*sgrna_summary.txt"
        # glob() yields matches in filesystem-dependent order; sort them so the
        # file that gets analysed is the same on every run and machine
        matches = sorted(glob(pattern, recursive=True))
        if matches:
            if len(matches) > 1:
                print(f"Warning: {len(matches)} sgrna_summary.txt files found for "
                      f"{condition}-{donor}; using {matches[0]}")
            df = process_sgrna_file(matches[0])  # take first match
            if df is not None:
                all_results[(condition, donor)] = df
                print(f"Processed {condition}-{donor}: {len(df)} sgRNAs")
            else:
                print(f"Failed to process {condition}-{donor}")
        else:
            print(f"No sgrna_summary.txt found for {condition}-{donor}")

# Step 2: Plot volcano plots for each donor-condition
# Grid layout: one row per condition, one column per donor
if all_results:
    # squeeze=False keeps `axes` two-dimensional, so axes[i, j] also works when
    # there is a single condition or a single donor (the default would collapse
    # the array to 1-D or to a bare Axes and the indexing below would fail)
    fig, axes = plt.subplots(len(conditions), len(donors), figsize=(14, 12), squeeze=False)

    for i, condition in enumerate(conditions):
        for j, donor in enumerate(donors):
            ax = axes[i, j]
            key = (condition, donor)
            if key in all_results:
                df = all_results[key]
                sns.scatterplot(data=df, x='Selected_lfc', y='minus_log10_pval',
                                hue='Category', palette=palette, ax=ax, legend=False, alpha=0.7)
                # Dashed guides at the significance and effect-size cut-offs
                ax.axhline(-np.log10(pval_threshold), linestyle='--', color='black')
                ax.axvline(lfc_threshold, linestyle='--', color='black')
                ax.axvline(-lfc_threshold, linestyle='--', color='black')
                label_top_sgrnas(ax, df)
                ax.set_title(f'{condition} - {donor}\n({len(df)} sgRNAs)')
            else:
                # Keep the (empty) panel so the grid layout stays regular
                ax.set_title(f'{condition} - {donor}\n(No file found)')
            ax.set_xlabel('LFC')
            ax.set_ylabel('-log10(pval)')

    plt.tight_layout()
    #plt.savefig("volcano_sgrna_separate.svg", format="svg")
    plt.show()

    # Print summary statistics
    print("\n=== Summary Statistics ===")
    for (condition, donor), df in all_results.items():
        enriched = (df['Category'] == 'Enriched').sum()
        depleted = (df['Category'] == 'Depleted').sum()
        total = len(df)
        print(f"{condition}-{donor}: {total} sgRNAs, {enriched} enriched, {depleted} depleted")
else:
    print("No sgRNA data found to plot!")

In [None]:
# Find the LFC range of the non-targeting (non-editing) control guides
from scipy import stats

# Get non-targeting control LFCs for background distribution
# (pooled over every (condition, donor) table in all_results)
ntc_lfcs = []
for (condition, donor), df in all_results.items():
    ntc_guides = df[df['Gene'] == 'Non-Targeting Control']
    ntc_lfcs.extend(ntc_guides['Selected_lfc'].values)

# Convert to numpy array for easier calculations
ntc_lfcs = np.array(ntc_lfcs)

# Without any control guide every statistic below would be NaN and np.min/np.max
# would fail with an unrelated-looking reduction error, so stop with a clear message
if ntc_lfcs.size == 0:
    raise ValueError("No 'Non-Targeting Control' guides found in all_results; "
                     "cannot derive NTC-based LFC thresholds")

# Calculate distribution statistics
ntc_mean = np.mean(ntc_lfcs)
ntc_std = np.std(ntc_lfcs)
ntc_median = np.median(ntc_lfcs)

print("Non-Targeting Control LFC Distribution:")
print(f"Mean: {ntc_mean:.4f}")
print(f"Median: {ntc_median:.4f}")
print(f"Standard Deviation: {ntc_std:.4f}")
print(f"Range: {np.min(ntc_lfcs):.4f} to {np.max(ntc_lfcs):.4f}")

# Define significance thresholds based on standard deviations
# Common thresholds: 1.96 SD (95% CI), 2 SD, 2.5 SD, 3 SD
sd_thresholds = [1.96, 2, 2.5, 3]

print("\nSignificance Thresholds (based on NTC distribution):")
for sd in sd_thresholds:
    upper_threshold = ntc_mean + sd * ntc_std
    lower_threshold = ntc_mean - sd * ntc_std

    # Calculate what percentage of NTCs fall within this range
    within_range = np.sum((ntc_lfcs >= lower_threshold) & (ntc_lfcs <= upper_threshold))
    percentage_within = (within_range / len(ntc_lfcs)) * 100

    print(f"{sd} SD: LFC outside [{lower_threshold:.4f}, {upper_threshold:.4f}] "
          f"({percentage_within:.1f}% of NTCs within range)")

# Test for normality of NTC distribution
# NOTE(review): the NTCs are pooled across all tables, so the sample can be large;
# check the SciPy notes on Shapiro-Wilk p-values for large samples before relying on it
shapiro_stat, shapiro_p = stats.shapiro(ntc_lfcs)
print(f"\nNormality Test (Shapiro-Wilk):")
print(f"Statistic: {shapiro_stat:.4f}, p-value: {shapiro_p:.4f}")
# Report the comparison that actually holds (the message used to say "p > 0.05"
# even when the verdict was "not normal")
if shapiro_p > 0.05:
    print("Distribution is normal (p > 0.05)")
else:
    print("Distribution is not normal (p <= 0.05)")

# Recommended threshold based on common practice
# 2 SD is commonly used, but 1.96 SD corresponds to 95% confidence interval
recommended_sd = 2
recommended_upper = ntc_mean + recommended_sd * ntc_std
recommended_lower = ntc_mean - recommended_sd * ntc_std

print(f"\nRecommended Threshold ({recommended_sd} SD):")
print(f"Significant if LFC < {recommended_lower:.4f} (depleted)")
print(f"Significant if LFC > {recommended_upper:.4f} (enriched)")

# Visualize the distribution
# Two panels: histogram of the pooled NTC LFCs (left) and a normal Q-Q plot (right).
# Relies on ntc_lfcs, ntc_mean, ntc_std and the recommended_* values computed above.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Histogram with normal overlay
# density=True puts the histogram on the same scale as the normal pdf drawn below
ax1.hist(ntc_lfcs, bins=30, density=True, alpha=0.7, color='lightblue', edgecolor='black')
ax1.axvline(ntc_mean, color='red', linestyle='--', label=f'Mean: {ntc_mean:.3f}')
ax1.axvline(recommended_lower, color='orange', linestyle='--', label=f'-{recommended_sd}SD: {recommended_lower:.3f}')
ax1.axvline(recommended_upper, color='orange', linestyle='--', label=f'+{recommended_sd}SD: {recommended_upper:.3f}')

# Overlay normal distribution
# Normal pdf with the NTC mean / SD, evaluated over the observed NTC LFC range
x = np.linspace(ntc_lfcs.min(), ntc_lfcs.max(), 100)
normal_curve = stats.norm.pdf(x, ntc_mean, ntc_std)
ax1.plot(x, normal_curve, 'r-', linewidth=2, label='Normal fit')

ax1.set_xlabel('Log Fold Change')
ax1.set_ylabel('Density')
ax1.set_title('Non-Targeting Control LFC Distribution')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Q-Q plot to check normality
stats.probplot(ntc_lfcs, dist="norm", plot=ax2)
ax2.set_title('Q-Q Plot (Normal Distribution)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# Written relative to the current working directory
plt.savefig('ntc_distribution_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

# Function to classify genes based on NTC-derived thresholds
def classify_by_ntc_threshold(lfc, sd_threshold=2):
    """
    Label an LFC relative to the non-targeting control (NTC) distribution.

    The cut-offs are ntc_mean +/- sd_threshold * ntc_std, read from the
    module-level NTC statistics at call time.

    Parameters:
    lfc: Log fold change value
    sd_threshold: Number of standard deviations from NTC mean

    Returns:
    'Enriched' when lfc is strictly above the upper cut-off, 'Depleted' when it
    is strictly below the lower one, 'Not Significant' in every other case
    """
    half_width = sd_threshold * ntc_std

    if lfc > ntc_mean + half_width:
        return 'Enriched'
    if lfc < ntc_mean - half_width:
        return 'Depleted'
    return 'Not Significant'

# Example usage - apply to your data
#print(f"\nExample classification using {recommended_sd} SD threshold:")
#example_lfcs = [-1.5, -0.5, 0.0, 0.5, 1.5]
#for lfc in example_lfcs:
#    classification = classify_by_ntc_threshold(lfc, recommended_sd)
#    print(f"LFC {lfc:5.1f}: {classification}")

# Calculate empirical percentiles (alternative approach)
# These come straight from the observed NTC LFCs and do not assume normality
print(f"\nEmpirical Percentiles of NTC Distribution:")
percentiles = [1, 2.5, 5, 95, 97.5, 99]
for p in percentiles:
    value = np.percentile(ntc_lfcs, p)
    print(f"{p:4.1f}th percentile: {value:.4f}")

# Suggest final thresholds
print(f"\nFinal Recommendations:")
# The label is derived from recommended_sd so it cannot drift from the
# recommended_lower / recommended_upper values it describes
print(f"1. Statistical approach ({recommended_sd} SD): LFC < {recommended_lower:.3f} or > {recommended_upper:.3f}")
print(f"2. Empirical approach (2.5th/97.5th percentile): LFC < {np.percentile(ntc_lfcs, 2.5):.3f} or > {np.percentile(ntc_lfcs, 97.5):.3f}")
print(f"3. Conservative approach (3 SD): LFC < {ntc_mean - 3*ntc_std:.3f} or > {ntc_mean + 3*ntc_std:.3f}")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches


        
def create_guide_distribution_plot(all_results, target_genes, conditions, donors,
                                 figsize=(12, 8), alpha_bg=0.3, alpha_target=0.8):
    """
    Create a plot showing LFC distribution of all guides with highlights for target genes

    The figure is a grid with one row per target gene and one column per
    condition. Every panel shows the density histogram of the donor-averaged LFC
    of all guides of that condition; the guides of the panel's gene are drawn on
    top as thick vertical lines annotated with their guide ID.

    Parameters:
    - all_results: dictionary with (condition, donor) keys and dataframes
      (the 'id', 'Selected_lfc', 'Selected_pval' and 'Gene' columns are used)
    - target_genes: list of gene names to highlight
    - conditions: list of conditions to plot; a condition is skipped (its
      panels stay empty) unless both donors have data and share guides
    - donors: list of donors; donors[0] and donors[1] are combined per guide
      (LFC averaged, p-values merged with Fisher's method)
    - figsize: figure size passed to plt.subplots
    - alpha_bg: opacity of the background histogram
    - alpha_target: opacity of the target-gene guide lines

    Returns:
    - the matplotlib Figure (it is neither saved nor shown here)
    """

    # First, let's combine donors using Fisher's method (like in your previous code)
    from scipy import stats

    def fisher_combine_pvalues(p1, p2):
        """Combine two p-values with Fisher's method (chi-square with 4 d.o.f.)."""
        if pd.isna(p1) or pd.isna(p2):
            return np.nan
        if p1 == 0 or p2 == 0:
            # log(0) is undefined; report the smallest value used elsewhere in this notebook
            return 1e-300
        chi_squared = -2 * (np.log(p1) + np.log(p2))
        # Survival function rather than 1 - cdf: the subtraction rounds to exactly
        # 0 once the combined p-value falls below double-precision resolution
        combined_pval = stats.chi2.sf(chi_squared, df=4)
        return combined_pval

    # Combine data across donors
    combined_results = {}
    for condition in conditions:
        df1 = all_results.get((condition, donors[0]))
        df2 = all_results.get((condition, donors[1]))
        if df1 is not None and df2 is not None:
            # Inner join: only guides present in both donors are kept
            merged = pd.merge(df1, df2, on='id', suffixes=('_1', '_2'))
            # Average LFC
            merged['Selected_lfc'] = merged[['Selected_lfc_1', 'Selected_lfc_2']].mean(axis=1)
            # Combine p-values using Fisher's method
            merged['Selected_pval'] = merged.apply(
                lambda row: fisher_combine_pvalues(row['Selected_pval_1'], row['Selected_pval_2']),
                axis=1
            )
            # Keep gene information
            if 'Gene_1' in merged.columns:
                merged['Gene'] = merged['Gene_1']
            elif 'Gene_2' in merged.columns:
                merged['Gene'] = merged['Gene_2']

            combined_results[condition] = merged

    # Create the plot
    # squeeze=False always returns a 2-D Axes array, so axes[i, j] is valid for a
    # single gene and/or a single condition as well (the default squeezing made
    # the single-gene case fail)
    fig, axes = plt.subplots(len(target_genes), len(conditions),
                            figsize=figsize, sharey=True, sharex=True,
                            squeeze=False)

    # Define colors for each gene
    gene_colors = plt.cm.Set3(np.linspace(0, 1, len(target_genes)))

    for j, condition in enumerate(conditions):
        if condition not in combined_results:
            continue

        df = combined_results[condition]
        if df.empty:
            # No guide shared between the two donors: nothing to draw, and
            # np.percentile below would fail on an empty array
            continue
        all_lfcs = df['Selected_lfc'].values

        # Get LFC range for consistent x-axis (1st-99th percentile, padded by 10% below)
        lfc_min, lfc_max = np.percentile(all_lfcs, [1, 99])
        lfc_range = lfc_max - lfc_min

        for i, gene in enumerate(target_genes):
            ax = axes[i, j]

            # Plot background distribution of all guides as density
            ax.hist(all_lfcs, bins=50, alpha=alpha_bg, color='lightgray',
                   density=True, label='All guides' if i == 0 and j == 0 else "")

            # Get guides for this specific gene
            gene_guides = df[df['Gene'] == gene]

            if len(gene_guides) > 0:
                # Plot individual guide LFCs as vertical lines
                for idx, (_, guide_row) in enumerate(gene_guides.iterrows()):
                    lfc = guide_row['Selected_lfc']
                    guide_id = guide_row['id']

                    # Create thick vertical line for each guide
                    ax.axvline(x=lfc, color=gene_colors[i], alpha=alpha_target,
                              linewidth=4, label=f'{gene} guides' if idx == 0 else "")

                    # Add guide ID as text annotation (only the part after the last
                    # underscore); successive guides are staggered downwards
                    y_pos = ax.get_ylim()[1] * 0.9 - idx * 0.1
                    ax.text(lfc, y_pos, guide_id.split('_')[-1] if '_' in guide_id else guide_id,
                           rotation=90, fontsize=8, ha='center', va='top',
                           color=gene_colors[i], weight='bold')

                # Add gene name and summary statistics
                mean_lfc = gene_guides['Selected_lfc'].mean()
                n_guides = len(gene_guides)

                # Add summary text box
                textstr = f'{gene}\nn={n_guides}\nmean LFC={mean_lfc:.2f}'
                props = dict(boxstyle='round', facecolor=gene_colors[i], alpha=0.3)
                ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=10,
                       verticalalignment='top', bbox=props)

            # Formatting
            ax.set_xlim(lfc_min - lfc_range*0.1, lfc_max + lfc_range*0.1)
            ax.grid(True, alpha=0.3)

            # Labels
            if i == 0:
                ax.set_title(f'{condition}', fontsize=14, weight='bold')
            if j == 0:
                ax.set_ylabel(f'{gene}\nDensity', fontsize=12, weight='bold')
            if i == len(target_genes) - 1:
                ax.set_xlabel('log₂ Fold Change', fontsize=12)

    # Add overall legend (generic handles; the per-gene colours are not listed)
    legend_elements = [
        mpatches.Patch(color='lightgray', alpha=alpha_bg, label='All guides distribution'),
        mpatches.Patch(color='black', alpha=alpha_target, label='Target gene guides')
    ]
    fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98))

    plt.tight_layout()
    # Leave room on the right for the figure-level legend
    plt.subplots_adjust(right=0.85)
    return fig

# Define your target genes
target_genes = ['CD7', 'KIAA0319L', 'RPP21', 'PTEN', 'RASA2']

# Create the plot
# Skipped entirely when no sgRNA table was loaded into all_results
if all_results:
    fig = create_guide_distribution_plot(
        all_results=all_results,
        target_genes=target_genes,
        conditions=conditions,
        donors=donors,
        figsize=(15, 12),
        alpha_bg=0.3,
        alpha_target=0.8
    )
    
    # Written relative to the current working directory
    plt.savefig("guide_distribution_plot.svg", format="svg", dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print summary statistics for each gene
    # Per-donor guide count and mean LFC, only for conditions where both donors have data
    print("\n=== Target Gene Summary ===")
    for condition in conditions:
        if (condition, donors[0]) in all_results and (condition, donors[1]) in all_results:
            df1 = all_results[(condition, donors[0])]
            df2 = all_results[(condition, donors[1])]
            
            print(f"\n{condition}:")
            for gene in target_genes:
                guides1 = df1[df1['Gene'] == gene]
                guides2 = df2[df2['Gene'] == gene]
                
                # Report the gene when at least one donor has guides for it; the
                # donor without guides is then printed with 0 guides and a nan mean
                if len(guides1) > 0 or len(guides2) > 0:
                    print(f"  {gene}:")
                    print(f"    {donors[0]}: {len(guides1)} guides, mean LFC: {guides1['Selected_lfc'].mean():.3f}")
                    print(f"    {donors[1]}: {len(guides2)} guides, mean LFC: {guides2['Selected_lfc'].mean():.3f}")
else:
    print("No data available for plotting!")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

def create_guide_distribution_bars(all_results, target_genes, conditions, donors,
                                 figsize=(12, 10), bar_height=0.8):
    """
    Create horizontal bar plots showing guide LFC distributions for target genes

    Each gene gets its own panel with one row per condition-donor combination.
    The row background is a grayscale strip encoding the density of all guides
    pooled over `all_results` (darker = more guides); every guide of the panel's
    gene is drawn on its row as a thick vertical tick in the condition's colour.

    Parameters:
    - all_results: dictionary with (condition, donor) keys and dataframes
      (the 'Selected_lfc' and 'Gene' columns are used)
    - target_genes: list of gene names, one panel each
    - conditions: list of conditions; conditions without a built-in colour get
      one from the tab10 colormap
    - donors: list of donors (rows of the same condition are adjacent)
    - figsize: figure size passed to plt.subplots
    - bar_height: height of each row's strip and guide ticks, in y-axis units

    Returns:
    - the matplotlib Figure (it is neither saved nor shown here)
    """

    # Create figure with subplots for each gene
    fig, axes = plt.subplots(len(target_genes), 1, figsize=figsize,
                            sharex=True, gridspec_kw={'hspace': 0.3})

    # Handle single gene case (plt.subplots then returns a bare Axes)
    if len(target_genes) == 1:
        axes = [axes]

    # Define colors for each condition
    condition_colors = {
        'AAV6': '#FF6B6B',      # Red
        'Ark315': '#4ECDC4',    # Teal
        'Ark312': '#45B7D1',    # Blue
    }
    # Conditions missing from the palette used to raise KeyError as soon as one of
    # their guides was drawn; give them a fallback colour instead
    unlisted = [c for c in conditions if c not in condition_colors]
    for k, condition in enumerate(unlisted):
        condition_colors[condition] = plt.cm.tab10(k % 10)

    # Get overall LFC range for consistent x-axis
    all_lfcs = []
    for df in all_results.values():
        all_lfcs.extend(df['Selected_lfc'].values)

    lfc_min, lfc_max = np.percentile(all_lfcs, [1, 99])
    lfc_range = lfc_max - lfc_min
    x_min, x_max = lfc_min - lfc_range*0.1, lfc_max + lfc_range*0.1

    # Create bins for background distribution
    n_bins = 100
    bin_edges = np.linspace(x_min, x_max, n_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_width = bin_edges[1] - bin_edges[0]

    # Calculate background distribution (all guides)
    all_lfc_hist, _ = np.histogram(all_lfcs, bins=bin_edges, density=True)
    # Normalize to 0-1 for background intensity
    all_lfc_hist_norm = all_lfc_hist / np.max(all_lfc_hist)

    # The background strip is the same for every row of every panel, so build its
    # geometry and colours once. Gray level 1.0 is white: higher density -> darker.
    strip_lefts = bin_centers - bin_width / 2
    strip_colors = [str(1 - intensity * 0.7) for intensity in all_lfc_hist_norm]

    # One row per condition-donor combination; conditions vary slowest so the
    # donors of a condition sit next to each other. Row index == y position.
    rows = [(condition, donor) for condition in conditions for donor in donors]
    y_positions = list(range(len(rows)))
    y_labels = [f"{condition}-{donor}" for condition, donor in rows]

    for ax, gene in zip(axes, target_genes):
        # Plot background distribution as grayscale bars for each row
        # (one barh call draws all bins of a row)
        for y_pos in y_positions:
            ax.barh(y_pos, bin_width, left=strip_lefts,
                   height=bar_height, color=strip_colors,
                   alpha=0.8, edgecolor='none')

        # Plot guides for this gene
        for y_pos, (condition, donor) in enumerate(rows):
            df = all_results.get((condition, donor))
            if df is None:
                # No data for this combination: the row keeps only its background
                continue

            # Plot individual guides as thick vertical lines (no text labels)
            for lfc in df.loc[df['Gene'] == gene, 'Selected_lfc']:
                ax.plot([lfc, lfc], [y_pos - bar_height/2, y_pos + bar_height/2],
                       color=condition_colors[condition], linewidth=4,
                       alpha=0.9, solid_capstyle='round')

        # Formatting
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(-0.5, len(y_positions) - 0.5)
        ax.set_yticks(y_positions)
        ax.set_yticklabels(y_labels)
        ax.grid(True, alpha=0.3, axis='x')

        # Add gene name as title (horizontal label to the left of the panel)
        ax.set_ylabel(f'{gene}', fontsize=14, weight='bold', rotation=0,
                     labelpad=50, ha='right', va='center')

        # Add vertical line at LFC = 0
        ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)

        # Remove top and right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # Set x-label only for bottom plot
    axes[-1].set_xlabel('log₂ Fold Change', fontsize=12, weight='bold')

    # Add legend
    legend_elements = []
    for condition, color in condition_colors.items():
        legend_elements.append(mpatches.Patch(color=color, label=condition))
    legend_elements.append(mpatches.Patch(color='gray', alpha=0.5, label='All guides density'))

    fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98))

    # Add title
    fig.suptitle('Guide LFC Distribution by Gene and Condition', fontsize=16, weight='bold')

    plt.tight_layout()
    # Leave room for the suptitle (top) and the figure-level legend (right)
    plt.subplots_adjust(top=0.93, right=0.85)
    return fig

# Define your target genes
target_genes = ['CD7', 'KIAA0319L', 'RPP21', 'PTEN', 'RASA2']

# Create the plot
# Skipped entirely when no sgRNA table was loaded into all_results
if all_results:
    fig = create_guide_distribution_bars(
        all_results=all_results,
        target_genes=target_genes,
        conditions=conditions,
        donors=donors,
        figsize=(14, 12),
        bar_height=0.8
    )
    
    # NOTE(review): the later non-targeting-control version of this plot saves to the
    # same file name, so running both cells keeps only the later figure on disk
    plt.savefig("guide_distribution_bars.svg", format="svg", dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print summary statistics for each gene
    # Guide count and mean ± SD of the LFC for every condition-donor combination with data
    print("\n=== Target Gene Guide Summary ===")
    for gene in target_genes:
        print(f"\n{gene}:")
        for condition in conditions:
            for donor in donors:
                if (condition, donor) in all_results:
                    df = all_results[(condition, donor)]
                    gene_guides = df[df['Gene'] == gene]
                    
                    if len(gene_guides) > 0:
                        mean_lfc = gene_guides['Selected_lfc'].mean()
                        std_lfc = gene_guides['Selected_lfc'].std()
                        print(f"  {condition}-{donor}: {len(gene_guides)} guides, "
                              f"mean LFC: {mean_lfc:.3f} ± {std_lfc:.3f}")
                    else:
                        print(f"  {condition}-{donor}: No guides found")
else:
    print("No data available for plotting!")

In [None]:
# New version: use the non-targeting control guides as the background distribution

In [None]:
import matplotlib.patches as mpatches
def create_guide_distribution_bars(all_results, target_genes, conditions, donors,
                                 figsize=(12, 10), bar_height=0.8):
    """
    Create horizontal bar plots showing guide LFC distributions for target genes

    Each gene gets its own panel with one row per condition-donor combination.
    Background shows non-targeting control guide distribution: a grayscale strip
    encoding the density of the 'Non-Targeting Control' guides pooled over
    `all_results` (darker = more guides). Every guide of the panel's gene is
    drawn on its row as a thick vertical tick in the condition's colour.

    Parameters:
    - all_results: dictionary with (condition, donor) keys and dataframes
      (the 'Selected_lfc' and 'Gene' columns are used)
    - target_genes: list of gene names, one panel each
    - conditions: list of conditions; conditions without a built-in colour get
      one from the tab10 colormap
    - donors: list of donors (rows of the same condition are adjacent)
    - figsize: figure size passed to plt.subplots
    - bar_height: height of each row's strip and guide ticks, in y-axis units

    Returns:
    - the matplotlib Figure (it is neither saved nor shown here)
    """

    # Create figure with subplots for each gene
    fig, axes = plt.subplots(len(target_genes), 1, figsize=figsize,
                            sharex=True, gridspec_kw={'hspace': 0.3})

    # Handle single gene case (plt.subplots then returns a bare Axes)
    if len(target_genes) == 1:
        axes = [axes]

    # Define colors for each condition
    condition_colors = {
        'AAV6': '#FF6B6B',      # Red
        'Ark315': '#4ECDC4',    # Teal
        'Ark312': '#45B7D1',    # Blue
    }
    # Conditions missing from the palette used to raise KeyError as soon as one of
    # their guides was drawn; give them a fallback colour instead
    unlisted = [c for c in conditions if c not in condition_colors]
    for k, condition in enumerate(unlisted):
        condition_colors[condition] = plt.cm.tab10(k % 10)

    # Get non-targeting control LFCs for background distribution
    ntc_lfcs = []
    for df in all_results.values():
        ntc_guides = df[df['Gene'] == 'Non-Targeting Control']
        ntc_lfcs.extend(ntc_guides['Selected_lfc'].values)

    # Get overall LFC range for consistent x-axis (including both NTC and target genes)
    all_lfcs = []
    for df in all_results.values():
        all_lfcs.extend(df['Selected_lfc'].values)

    lfc_min, lfc_max = np.percentile(all_lfcs, [1, 99])
    lfc_range = lfc_max - lfc_min
    x_min, x_max = lfc_min - lfc_range*0.1, lfc_max + lfc_range*0.1

    # Create bins for background distribution
    n_bins = 100
    bin_edges = np.linspace(x_min, x_max, n_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_width = bin_edges[1] - bin_edges[0]

    # Calculate non-targeting control distribution
    if len(ntc_lfcs) > 0:
        ntc_lfc_hist, _ = np.histogram(ntc_lfcs, bins=bin_edges, density=True)
        peak = np.max(ntc_lfc_hist)
        if np.isfinite(peak) and peak > 0:
            # Normalize to 0-1 for background intensity
            ntc_lfc_hist_norm = ntc_lfc_hist / peak
        else:
            # Every NTC guide lies outside the plotted range: the histogram has no
            # mass, and dividing by its peak would yield NaN (an invalid colour)
            print("Warning: No non-targeting control guides inside the plotted LFC range!")
            ntc_lfc_hist_norm = np.zeros(len(bin_centers))
    else:
        print("Warning: No non-targeting control guides found!")
        ntc_lfc_hist_norm = np.zeros(len(bin_centers))

    # The background strip is the same for every row of every panel, so build its
    # geometry and colours once. Gray level 1.0 is white: higher density -> darker.
    strip_lefts = bin_centers - bin_width / 2
    strip_colors = [str(1 - intensity * 0.7) for intensity in ntc_lfc_hist_norm]

    # One row per condition-donor combination; conditions vary slowest so the
    # donors of a condition sit next to each other. Row index == y position.
    rows = [(condition, donor) for condition in conditions for donor in donors]
    y_positions = list(range(len(rows)))
    y_labels = [f"{condition}-{donor}" for condition, donor in rows]

    for ax, gene in zip(axes, target_genes):
        # Plot non-targeting control distribution as grayscale bars for each row
        # (one barh call draws all bins of a row)
        for y_pos in y_positions:
            ax.barh(y_pos, bin_width, left=strip_lefts,
                   height=bar_height, color=strip_colors,
                   alpha=0.8, edgecolor='none')

        # Plot guides for this gene
        for y_pos, (condition, donor) in enumerate(rows):
            df = all_results.get((condition, donor))
            if df is None:
                # No data for this combination: the row keeps only its background
                continue

            # Plot individual guides as thick vertical lines (no text labels)
            for lfc in df.loc[df['Gene'] == gene, 'Selected_lfc']:
                ax.plot([lfc, lfc], [y_pos - bar_height/2, y_pos + bar_height/2],
                       color=condition_colors[condition], linewidth=4,
                       alpha=0.9, solid_capstyle='round')

        # Formatting
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(-0.5, len(y_positions) - 0.5)
        ax.set_yticks(y_positions)
        ax.set_yticklabels(y_labels)
        ax.grid(True, alpha=0.3, axis='x')

        # Add gene name as title (horizontal label to the left of the panel)
        ax.set_ylabel(f'{gene}', fontsize=14, weight='bold', rotation=0,
                     labelpad=50, ha='right', va='center')

        # Add vertical line at LFC = 0
        ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)

        # Remove top and right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # Set x-label only for bottom plot
    axes[-1].set_xlabel('log₂ Fold Change', fontsize=12, weight='bold')

    # Add legend
    legend_elements = []
    for condition, color in condition_colors.items():
        legend_elements.append(mpatches.Patch(color=color, label=condition))
    legend_elements.append(mpatches.Patch(color='gray', alpha=0.5, label='Non-targeting control density'))

    fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98))

    # Add title
    fig.suptitle('Guide LFC Distribution by Gene and Condition', fontsize=16, weight='bold')

    plt.tight_layout()
    # Leave room for the suptitle (top) and the figure-level legend (right)
    plt.subplots_adjust(top=0.93, right=0.85)
    return fig

# Define your target genes
target_genes = ['CD7', 'KIAA0319L', 'RPP21', 'PTEN', 'RASA2']

# Create the plot
# Skipped entirely when no sgRNA table was loaded into all_results
if all_results:
    fig = create_guide_distribution_bars(
        all_results=all_results,
        target_genes=target_genes,
        conditions=conditions,
        donors=donors,
        figsize=(14, 12),
        bar_height=0.8
    )
    
    # NOTE(review): the earlier all-guides version of this plot saves to the same
    # file name, so this call overwrites that figure when both cells are run
    plt.savefig("guide_distribution_bars.svg", format="svg", dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
import matplotlib.patches as mpatches
def create_guide_distribution_bars(all_results, target_genes, conditions, donors, 
                                 figsize=(12, 10), bar_height=0.8):
    """
    Draw per-gene panels of individual guide LFCs over a non-targeting-control backdrop.

    Each gene in ``target_genes`` gets its own panel. Every panel has one row per
    (condition, donor) pair, ordered condition-major so that the donors of one
    condition sit next to each other. Each guide of the gene is drawn as a thick
    vertical tick at its ``Selected_lfc``. Behind the ticks, every row is shaded
    with the pooled 'Non-Targeting Control' LFC histogram (darker = denser).

    The shared x-axis spans the 1st-99th percentile of all finite LFCs in
    ``all_results``, padded by 10% of that range on each side; guides outside
    this window fall outside the visible axis range.

    Parameters
    ----------
    all_results : mapping
        Maps ``(condition, donor)`` to a DataFrame that has ``'Gene'`` and
        ``'Selected_lfc'`` columns. Pairs missing from the mapping leave an
        empty (background-only) row.
    target_genes : sequence
        Gene names to plot, one panel each, top to bottom.
    conditions, donors : sequences
        Define the rows of every panel (``len(conditions) * len(donors)`` rows).
        Conditions other than 'AAV6', 'Ark315' and 'Ark312' are coloured from
        the active matplotlib colour cycle instead of raising ``KeyError``.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    bar_height : float
        Height of each row's background band and of the guide ticks.

    Returns
    -------
    The matplotlib figure that was created.

    Raises
    ------
    ValueError
        If ``target_genes`` is empty or ``all_results`` holds no finite LFC.
    """
    if len(target_genes) == 0:
        raise ValueError("target_genes must contain at least one gene")

    # Pool LFCs across every (condition, donor) table; NTC guides are pooled
    # separately for the background distribution.
    lfc_chunks, ntc_chunks = [], []
    for df in all_results.values():
        lfc = df['Selected_lfc'].to_numpy(dtype=float, na_value=np.nan)
        is_ntc = (df['Gene'] == 'Non-Targeting Control').to_numpy()
        lfc_chunks.append(lfc)
        ntc_chunks.append(lfc[is_ntc])
    all_lfcs = np.concatenate(lfc_chunks) if lfc_chunks else np.empty(0)
    ntc_lfcs = np.concatenate(ntc_chunks) if ntc_chunks else np.empty(0)

    # Non-finite values would turn the percentile-based axis limits into NaN.
    all_lfcs = all_lfcs[np.isfinite(all_lfcs)]
    ntc_lfcs = ntc_lfcs[np.isfinite(ntc_lfcs)]
    if all_lfcs.size == 0:
        raise ValueError("all_results contains no finite 'Selected_lfc' values to plot")

    # Consistent x-axis: robust 1-99% range plus 10% padding on each side.
    lfc_min, lfc_max = np.percentile(all_lfcs, [1, 99])
    lfc_range = lfc_max - lfc_min
    pad = lfc_range * 0.1 if lfc_range > 0 else 0.5  # avoid a zero-width axis
    x_min, x_max = lfc_min - pad, lfc_max + pad

    # Bins for the background distribution
    n_bins = 100
    bin_edges = np.linspace(x_min, x_max, n_bins + 1)
    bin_width = bin_edges[1] - bin_edges[0]

    # NTC density scaled to 0-1. Raw counts are used because, with equal-width
    # bins, count/peak equals density/peak while staying finite when no NTC
    # value lands inside the binned window.
    ntc_density = np.zeros(n_bins)
    if ntc_lfcs.size == 0:
        print("Warning: No non-targeting control guides found!")
    else:
        counts, _ = np.histogram(ntc_lfcs, bins=bin_edges)
        peak = counts.max()
        if peak > 0:
            ntc_density = counts / peak

    # Matplotlib grayscale strings: '1.0' is white, smaller values are darker,
    # so denser bins come out darker (down to '0.3' at the peak).
    gray_levels = [str(level) for level in (1 - ntc_density * 0.7).tolist()]

    # Colours per condition; unknown conditions fall back to the "CN" colour
    # cycle so they are plotted (and listed in the legend) rather than failing.
    condition_colors = {
        'AAV6': '#FF6B6B',      # Red
        'Ark315': '#4ECDC4',    # Teal
        'Ark312': '#45B7D1',    # Blue
    }
    unknown = [c for c in dict.fromkeys(conditions) if c not in condition_colors]
    for k, condition in enumerate(unknown):
        condition_colors[condition] = f"C{k}"

    # Row layout is identical for every gene: conditions first, then donors,
    # so donors of the same condition are adjacent.
    row_keys = [(condition, donor) for condition in conditions for donor in donors]
    y_positions = list(range(len(row_keys)))
    y_labels = [f"{condition}-{donor}" for condition, donor in row_keys]

    # squeeze=False keeps `axes` indexable even for a single gene.
    fig, axes = plt.subplots(len(target_genes), 1, figsize=figsize,
                             sharex=True, squeeze=False,
                             gridspec_kw={'hspace': 0.3})
    axes = axes[:, 0]

    for ax, gene in zip(axes, target_genes):
        # Background: one broadcast barh call per row draws all bins at once.
        for y_pos in y_positions:
            ax.barh(y_pos, bin_width, left=bin_edges[:-1],
                    height=bar_height, color=gray_levels,
                    alpha=0.8, edgecolor='none')

        # Foreground: one thick vertical tick per guide of this gene.
        for y_pos, key in zip(y_positions, row_keys):
            df = all_results.get(key)
            if df is None:
                continue
            tick_color = condition_colors[key[0]]
            y_span = [y_pos - bar_height/2, y_pos + bar_height/2]
            for lfc in df.loc[df['Gene'] == gene, 'Selected_lfc']:
                ax.plot([lfc, lfc], y_span,
                        color=tick_color, linewidth=4,
                        alpha=0.9, solid_capstyle='round')

        # Formatting
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(-0.5, len(y_positions) - 0.5)
        ax.set_yticks(y_positions)
        ax.set_yticklabels(y_labels)
        ax.grid(True, alpha=0.3, axis='x')

        # Gene name as a horizontal label to the left of the panel
        ax.set_ylabel(f'{gene}', fontsize=14, weight='bold', rotation=0, 
                     labelpad=50, ha='right', va='center')

        # Reference line at LFC = 0
        ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)

        # Remove top and right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # Set x-label only for bottom plot
    axes[-1].set_xlabel('log₂ Fold Change', fontsize=12, weight='bold')

    # Legend: one patch per condition colour plus the NTC background
    legend_elements = [mpatches.Patch(color=color, label=condition)
                       for condition, color in condition_colors.items()]
    legend_elements.append(mpatches.Patch(color='gray', alpha=0.5, label='Non-targeting control density'))

    # NOTE: this anchors the legend's upper-right corner at the figure's
    # lower-left corner, i.e. outside the canvas; it only appears in output
    # rendered/saved with bbox_inches='tight'.
    fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0, 0))

    # Add title
    fig.suptitle('Guide LFC Distribution by Gene and Condition', fontsize=16, weight='bold')

    # Operate on `fig` explicitly rather than on pyplot's current figure.
    fig.tight_layout()
    fig.subplots_adjust(top=0.93, right=0.85)
    return fig

# Genes to display, one panel per gene
target_genes = ['CD7', 'KIAA0319L', 'SLC35B2', 'ST3GAL4', 'RPL8','RPS25','RPP21','EIF6','VAV1','LCP2']

# Skip plotting entirely when all_results is empty
if all_results:
    fig = create_guide_distribution_bars(
        all_results, target_genes, conditions, donors,
        figsize=(5, 10), bar_height=0.8,
    )

    # Export as SVG, then render inline.
    # NOTE: an earlier cell saves to this same filename; if both run from the
    # same working directory, this call overwrites that file.
    fig.savefig("guide_distribution_bars.svg", format="svg", dpi=300, bbox_inches='tight')
    plt.show()