In [None]:
# This script takes the Hypr-seq AnnData and the Tapestri loom file as input
# and outputs scRNA-seq data with detailed genotype annotations.

import scanpy as sc
import numpy as np
import pandas as pd

import argparse
import seaborn as sns
import sys
import loompy
import logging
import configparser
import anndata
import matplotlib.pyplot as plt
import re
import matplotlib.patches as mpatches
import os
import argparse
from pathlib import Path # For easier path handling
import anndata as ad

import copy
import scipy.sparse as sp

# Emit bare log messages (no level or timestamp prefix) to stdout at INFO level,
# so progress lines interleave cleanly with other console output.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# Read and filter the Hypr-seq anndata
def filter_hypr(adata, probe_level=False, min_counts=1000):
    """
    Filter low-UMI cells and collapse probe-level counts into gene-level counts.

    Parameters
    ----------
    adata : anndata.AnnData
        Cell-by-probe Hypr-seq matrix (dense or scipy sparse ``X``). Probe names
        are split at the first underscore and the first part is taken as the
        gene name. NOTE: ``adata`` is modified in place: cells below
        ``min_counts`` are dropped and a ``gene_name`` column is added to
        ``adata.var``.
    probe_level : bool, default False
        If True, keep every probe as its own feature instead of summing probes
        that share a gene name.
    min_counts : int, default 1000
        Minimum total count a cell needs in order to be kept.

    Returns
    -------
    anndata.AnnData
        A new cell-by-gene (or cell-by-probe) object holding the summed counts
        and a copy of the filtered ``obs``. ``uns``/``obsm`` are not carried over.
    """
    # Drop cells with fewer than `min_counts` counts (in place on `adata`).
    sc.pp.filter_cells(adata, min_counts=min_counts)

    # Decide which feature each probe is summed into.
    if probe_level:
        gene_names = list(adata.var_names)
    else:
        gene_names = [name.split('_')[0] for name in adata.var_names]
    adata.var['gene_name'] = gene_names

    # pd.DataFrame cannot wrap a scipy sparse matrix directly, so densify first.
    counts = adata.X.toarray() if sp.issparse(adata.X) else np.asarray(adata.X)

    # Rows = probes, columns = cells; sum rows that share a gene name.
    probe_by_cell = pd.DataFrame(counts.T, index=adata.var_names, columns=adata.obs_names)
    gene_by_cell = probe_by_cell.groupby(adata.var['gene_name']).sum()

    # Transpose back to cells x genes and wrap it in a fresh AnnData.
    adata_aggregated = sc.AnnData(X=gene_by_cell.T)
    adata_aggregated.obs = adata.obs.copy()

    return adata_aggregated


# function reads the loomfile downloaded from Tapestri portal
def read_tapestri_loom(filename):
    """
    Read data from MissionBio's formatted loom file.

    Parameters
    ----------
    filename : str
        Path to the loom file (.loom)

    Returns
    -------
    anndata.AnnData
        A cells x variants anndata object (the loom matrix is transposed) with:
        adata.X: GATK calls (the loom main matrix)
        adata.layers['e']: reads with evidence of mutation (loom layer 'AD')
        adata.layers['no_e']: reads without evidence of mutation (loom layer 'RO')
        adata.var_names: variant ids (loom row attribute 'id')
        adata.obs_names: barcodes, truncated at the first '-'
        adata.varm['amplicon'/'chrom'/'loc']: loom row attributes
        'amplicon', 'CHROM' and 'POS'
    """
    # The context manager closes the file even if reading a layer fails.
    with loompy.connect(filename) as loom_file:
        variant_names = loom_file.ra['id']
        amplicon_names = loom_file.ra['amplicon']
        chromosome = loom_file.ra['CHROM']
        location = loom_file.ra['POS']

        barcodes = [barcode.split('-')[0] for barcode in loom_file.ca['barcode']]

        # Read the main matrix once (variants x cells) and reuse it for the dtype.
        genotypes = loom_file[:, :]
        adata = anndata.AnnData(np.transpose(genotypes), dtype=genotypes.dtype)
        adata.layers['e'] = np.transpose(loom_file.layers['AD'][:, :])
        adata.layers['no_e'] = np.transpose(loom_file.layers['RO'][:, :])

        adata.var_names = variant_names
        adata.obs_names = barcodes
        # NOTE(review): varm usually holds 2-D arrays and adata.var may be a
        # better home for these; kept in varm for existing callers.
        adata.varm['amplicon'] = amplicon_names
        adata.varm['chrom'] = chromosome
        adata.varm['loc'] = location

    return adata

# Optional function, visualize the barcode distribution
def barcode_rank_plot(adata, minimum=0, xmax=None):
    """
    Show a log-log barcode rank plot of total counts per cell barcode.

    Parameters
    ----------
    adata : anndata.AnnData
        Cell-by-feature count matrix (dense or scipy sparse ``X``).
    minimum : number, default 0
        Only barcodes whose total count is strictly greater than this are plotted.
    xmax : number, optional
        Upper limit of the (log-scaled) rank axis.
    """
    # Per-cell totals. np.asarray(...).ravel() also flattens the (n, 1)
    # np.matrix that a scipy sparse .sum(axis=1) returns.
    cell_umi_counts = np.asarray(adata.X.sum(axis=1)).ravel()
    cell_umi_counts = cell_umi_counts[cell_umi_counts > minimum]

    # Sort descending so rank 1 is the barcode with the most counts.
    sorted_umi_counts = np.sort(cell_umi_counts)[::-1]

    plt.figure(figsize=(5, 4), dpi=150)
    sns.lineplot(x=range(1, len(sorted_umi_counts) + 1), y=sorted_umi_counts)
    plt.xlabel('Barcode Rank')
    plt.ylabel('UMI Count')
    plt.title('Cell Barcode Rank Plot')
    plt.yscale('log')  # Log scale for better visualization
    plt.xscale('log')
    if xmax is not None:
        # Ranks start at 1 and the axis is log-scaled, where 0 is not a valid
        # limit; use 1 as the lower bound.
        plt.xlim(1, xmax)
    plt.show()


# Function for finding the intersecting barcodes between the modalities
# We can drop the idea of mudata for simplicity
def find_intersecting_and_filter(adata_hypr, adata_loom):
    """
    Keep only the barcodes present in both modalities, in matching order.

    Parameters
    ----------
    adata_hypr : anndata.AnnData
        Hypr-seq (expression) object.
    adata_loom : anndata.AnnData
        Tapestri genotype object.

    Returns
    -------
    tuple of anndata.AnnData
        ``(adata_hypr_new, adata_loom_new)``: independent copies restricted to
        the shared barcodes. Both are ordered like the sorted common values
        returned by ``np.intersect1d``, so row i is the same barcode in each.
    """
    obs_hypr, obs_loom = adata_hypr.obs_names, adata_loom.obs_names
    _, idx_hypr, idx_loom = np.intersect1d(obs_hypr, obs_loom, return_indices=True)

    logger.info(f"Found {len(obs_hypr)} barcodes in modality hypr seq")
    logger.info(f"Found {len(obs_loom)} barcodes in modality loom file")
    logger.info(f"Found {len(idx_hypr)} intersecting barcodes")

    # Subsetting an AnnData returns a view. Copy so that later in-place edits
    # (e.g. adding .obs columns during annotation) act on real objects instead
    # of implicitly converting views.
    adata_hypr_new = adata_hypr[idx_hypr, :].copy()
    adata_loom_new = adata_loom[idx_loom, :].copy()

    return adata_hypr_new, adata_loom_new
    

# Step 1, determine the germline mutation from the loom file
# We envision that if certain mutation appeared too many times in the dataset
# (e.g., more than 25% of cells are het, or more than 25% are hom), we call it germline
def call_germline_mutations(adata_loom, cutoff=0.25):
    """
    Flag variants that are too common across cells to be treated as edits.

    Parameters
    ----------
    adata_loom : anndata.AnnData
        Genotype object whose ``X`` (dense or scipy sparse) holds per-cell calls.
        Only the values 1 (heterozygous) and 2 (homozygous) are counted here.
    cutoff : float, default 0.25
        A variant is flagged when the fraction of cells called 1, or the
        fraction called 2, is strictly greater than this.

    Returns
    -------
    list of str
        Flagged variant names, each listed once. Variants passing on the
        heterozygous fraction come first, then any further ones passing on the
        homozygous fraction.
    """
    genotypes = adata_loom.X
    if sp.issparse(genotypes):
        genotypes = genotypes.toarray()
    genotypes = np.asarray(genotypes)

    fraction_het = np.mean(genotypes == 1, axis=0)
    fraction_hom = np.mean(genotypes == 2, axis=0)

    het_germline = adata_loom.var_names[fraction_het > cutoff]
    hom_germline = adata_loom.var_names[fraction_hom > cutoff]

    # A variant can pass both thresholds when cutoff < 0.5. dict.fromkeys
    # de-duplicates while keeping first-seen order.
    return list(dict.fromkeys([*het_germline, *hom_germline]))


# For AAV control, we may ignore mixed cells and bystander effect
def call_AAV_control_edits(adata_loom, config_path='AAV_config.ini'):
    """
    List the variants that fall in the AAV control region with the control allele.

    Only one AAV control region is supported. The ``[AAV_control]`` config
    section must define ``chrom_name``, ``start``, ``end`` (inclusive) and
    ``variant_alleles`` (e.g. ``C/T``). A variant matches when it lies on that
    chromosome inside [start, end] and its allele string equals either
    ``variant_alleles`` or its base-complemented form (``C/T`` -> ``G/A``).

    Parameters
    ----------
    adata_loom : anndata.AnnData
        Genotype object whose ``var_names`` are ``chrom:pos:alleles`` strings.
        It is not modified.
    config_path : str
        Path to the INI configuration file.

    Returns
    -------
    list of str
        Matching variant names, in ``var_names`` order.

    Raises
    ------
    ValueError
        If the first or last character of ``variant_alleles`` is not A/C/G/T.
    """
    config = configparser.ConfigParser()
    config.read(config_path)

    chrom_name = config.get('AAV_control', 'chrom_name')
    start = config.getint('AAV_control', 'start')
    end = config.getint('AAV_control', 'end')
    variant_alleles = config.get('AAV_control', 'variant_alleles')

    complement_rule = {'C': 'G', 'G': 'C', 'A': 'T', 'T': 'A'}
    try:
        reverse_variant_alleles = complement_rule[variant_alleles[0]] + "/" + complement_rule[variant_alleles[-1]]
    except (KeyError, IndexError) as err:
        raise ValueError(f"The variant alleles is not supported. Current variant alleles is {variant_alleles}") from err
    accepted_alleles = {variant_alleles, reverse_variant_alleles}

    control_editing = []
    for variant in np.asarray(adata_loom.var_names.values):
        fields = variant.split(':')
        chrom, loc, edit_type = fields[0], int(fields[1]), fields[2]
        if chrom == chrom_name and start <= loc <= end and edit_type in accepted_alleles:
            control_editing.append(variant)

    return control_editing


def call_AAV_del_edits(adata_loom, config_path):
    """
    List the deletion variants that fall in the AAV deletion region.

    Only one region is supported. The ``[AAV_del]`` config section must define
    ``chrom_name``, ``start`` and ``end`` (inclusive). A variant matches when it
    lies on that chromosome inside [start, end] and its allele string ends
    with ``*`` (presumably how the Tapestri variant names mark deletions;
    confirm against the loom naming scheme).

    Parameters
    ----------
    adata_loom : anndata.AnnData
        Genotype object whose ``var_names`` are ``chrom:pos:alleles`` strings.
        It is not modified.
    config_path : str
        Path to the INI configuration file.

    Returns
    -------
    list of str
        Matching variant names, in ``var_names`` order.
    """
    config = configparser.ConfigParser()
    config.read(config_path)

    chrom_name = config.get('AAV_del', 'chrom_name')
    start = config.getint('AAV_del', 'start')
    end = config.getint('AAV_del', 'end')

    del_editing = []
    for variant in np.asarray(adata_loom.var_names.values):
        fields = variant.split(':')
        chrom, loc, edit_type = fields[0], int(fields[1]), fields[2]
        # endswith() is safe on an empty allele string, unlike edit_type[-1].
        if chrom == chrom_name and start <= loc <= end and edit_type.endswith("*"):
            del_editing.append(variant)

    return del_editing


def call_AAV_cells(adata_hypr, adata_loom, AAV_editing):
    """
    Find the hypr-seq cells that carry any of the given AAV edits.

    A cell counts as edited when its call in ``adata_loom.X`` is 1 (het) or
    2 (hom) for at least one variant in ``AAV_editing``. We do not perform the
    annotation in this function, to keep it consistent across callers; neither
    object is modified.

    Parameters
    ----------
    adata_hypr : anndata.AnnData
        Hypr-seq object whose positional indices are returned.
    adata_loom : anndata.AnnData
        Genotype object containing the variants named in ``AAV_editing``.
    AAV_editing : list of str
        Variant names (e.g. from ``call_AAV_control_edits``). May be empty.

    Returns
    -------
    numpy.ndarray
        Sorted, unique integer positions into ``adata_hypr.obs_names`` of the
        edited cells that are also present in ``adata_hypr``. The array is empty
        when ``AAV_editing`` is empty or no cell is edited.
    """
    # With no edits there are no edited cells. This also avoids
    # np.concatenate([]) raising.
    if len(AAV_editing) == 0:
        return np.array([], dtype=np.intp)

    edit_cols = [adata_loom.var_names.get_loc(edit) for edit in AAV_editing]
    genotypes = adata_loom.X[:, edit_cols]
    if sp.issparse(genotypes):
        genotypes = genotypes.toarray()
    # A cell is edited if any requested variant is called het (1) or hom (2).
    is_edited = np.isin(np.asarray(genotypes), (1, 2)).any(axis=1)
    edited_barcodes = adata_loom.obs_names[is_edited]

    # Translate loom barcodes into positions within adata_hypr.
    _, hypr_idx, _ = np.intersect1d(adata_hypr.obs_names, edited_barcodes, return_indices=True)
    return np.unique(hypr_idx)



def get_nearby_variants(
    variant_names, # List of all variant in the loom file
    target_loc, # loci of interest
    germline_amplicons, # List of Germline mutations 
    window_size # window size
):
    """
        Get the variants within ``window_size`` bp of ``target_loc``.

        A variant is kept when it is on the target's chromosome, its position
        lies in the inclusive window [pos - window_size, pos + window_size],
        its allele string is a single letter on each side of '/' (no '*'), and
        it is not listed in ``germline_amplicons``.

        Args:
            variant_names (iterable of str): Variant names ``chrom:pos:alleles``.
            target_loc (str): Target loci (e.g. chr1:123456:A/G); only its
                chromosome and position are used.
            germline_amplicons (iterable of str): Variant names to exclude.
            window_size (int): Half-width of the window, in positions.
        Returns:
            list: Matching variant names in input order. The target itself is
            included when present in ``variant_names`` and not germline.
    """
    chrom_t, loc_t, _ = target_loc.split(":")
    lower = int(loc_t) - window_size
    upper = int(loc_t) + window_size
    # A set makes the germline check O(1) per variant instead of a list scan.
    germline = set(germline_amplicons)
    # Exactly one letter on each side of the slash.
    single_letter_alleles = re.compile(r"^[A-Za-z]/[A-Za-z]$")

    nearby_variants = []
    for variant in variant_names:
        chrom, loc, edit_type = variant.split(":")
        if chrom != chrom_t or not lower <= int(loc) <= upper:
            continue
        if edit_type.endswith("*") or not single_letter_alleles.match(edit_type):
            continue
        if variant in germline:
            continue
        nearby_variants.append(variant)

    return nearby_variants



def plot_mutant_types(target_loc, new_matrix, mt_type, snp_name, save_path, index=""):
    """
    Save a per-variant mutated-cell bar chart and a row-clustered genotype heatmap.

    Two PDFs are written to ``save_path``:
    ``{snp_name}_{mt_type}_mutant_counts_{index}.pdf`` and
    ``{snp_name}_{mt_type}_mutant_matrix_{index}.pdf``.

    Parameters
    ----------
    target_loc : str
        Target variant name. Its bar is coloured red if it is a column of
        ``new_matrix``.
    new_matrix : pandas.DataFrame
        Cells x nearby-variants genotype matrix. NaNs should be filled
        beforehand (the caller in this file uses ``fillna(0)``). Row clustering
        presumably needs at least two rows; the caller only plots in that case.
    mt_type : str
        Label such as "Heterozygous"/"Homozygous", used in titles and file names.
    snp_name : str
        SNP name, used in titles and file names.
    save_path : str
        Output directory.
    index : str
        Extra tag appended to titles and file names.
    """
    # Number of cells with a non-zero call per variant (column).
    mutation_counts = new_matrix.apply(np.count_nonzero, axis=0)

    plt.figure(figsize=(10, 6))
    bars = plt.bar(mutation_counts.index, mutation_counts.values)
    for bar, count in zip(bars, mutation_counts.values):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            str(count),
            ha="center",
            va="bottom",
        )
    if target_loc in mutation_counts.index:
        bars[mutation_counts.index.tolist().index(target_loc)].set_color("r")
    red_patch = mpatches.Patch(color="red", label="Target Loci")
    plt.legend(handles=[red_patch])

    # Replace '/' in the name shown in the titles below (file names are built
    # from snp_name, not target_loc).
    target_loc = target_loc.replace("/", "-")

    plot_title = (
        f"Mutation Counts: {mt_type} - {target_loc} ({index}) with total cells {len(new_matrix)}"
    )
    plt.title(plot_title)
    plt.xticks(rotation=90)
    plt.xlabel("")
    plt.ylabel("Number of Mutated Cells")

    save_index = f"{snp_name}_{mt_type}_mutant_counts_{index}.pdf"
    plt.savefig(os.path.join(save_path, save_index), bbox_inches="tight")
    plt.close()

    # Cluster map of the matrix, clustering rows only.
    grid = sns.clustermap(
        new_matrix,
        cmap="YlGnBu",
        row_cluster=True,
        col_cluster=False,
        figsize=(10, 10),
    )
    # clustermap builds several axes (heatmap, dendrograms, colorbar), and the
    # pyplot "current" axes is not guaranteed to be the heatmap. Label the
    # heatmap axes explicitly so the title and labels land on it.
    grid.ax_heatmap.set_title(
        f"Mutation Matrix: {mt_type} - {snp_name} aka {target_loc} with shape {new_matrix.shape}"
    )
    grid.ax_heatmap.set_xlabel("Variant")
    grid.ax_heatmap.set_ylabel("Cell Barcode")
    save_index = f"{snp_name}_{mt_type}_mutant_matrix_{index}.pdf"
    # clustermap created a new figure, which is the current one here.
    plt.savefig(os.path.join(save_path, save_index), bbox_inches="tight")
    plt.close()




def process_bystander_cells(target_loc, cells, nearby_variants, adata_loom):
    """
    Split cells carrying the target edit into pure cells and bystander cells.

    Calls of 3 are first set to 0 in every column except ``target_loc``
    (3 is presumably Tapestri's "no call" genotype; confirm). A cell is then a
    bystander cell when any column other than ``target_loc`` is non-zero.

    Args:
        target_loc (str): Target loci (e.g. chr1:123456:A/G).
        cells (list-like of str): Barcodes of cells that carry the target edit.
        nearby_variants (list of str): Variant names to inspect, normally
            including ``target_loc`` (see ``get_nearby_variants``).
        adata_loom (anndata.AnnData): Genotype object indexed by barcode
            (obs) and variant name (var). Not modified.
    Returns:
        set: Pure cells (no bystander edit).
        set: Bystander cells.
        pd.DataFrame: Cells x ``nearby_variants`` genotype matrix with the
            3 -> 0 replacement applied to the non-target columns.
    """
    nearby_adata = adata_loom[cells, nearby_variants].copy()
    genotypes = nearby_adata.X
    if sp.issparse(genotypes):
        genotypes = genotypes.toarray()
    new_matrix = pd.DataFrame(
        np.asarray(genotypes), index=nearby_adata.obs_names, columns=nearby_adata.var_names
    )

    # Set 3 to 0 in all columns except the target loci column.
    bystander_cols = [col for col in new_matrix.columns if col != target_loc]
    if bystander_cols:
        bystander_block = new_matrix[bystander_cols]
        new_matrix[bystander_cols] = bystander_block.mask(bystander_block == 3, 0)

    # Any non-zero call outside the target column marks a bystander cell.
    bystander_only = new_matrix.loc[:, new_matrix.columns != target_loc]
    has_bystander_edit = (bystander_only.sum(axis=1) != 0).to_numpy()
    bystander_cells = set(bystander_only.index[has_bystander_edit])

    pure_cells = set(cells) - bystander_cells
    return pure_cells, bystander_cells, new_matrix



# Nearby variants to drop from the bystander check for specific targets. The
# original notes describe these as synonymous mutations, so an edit at one of
# these positions should not turn a target cell into a bystander cell.
# Keys are target variant names; values are the nearby variant names to ignore.
_IGNORED_NEARBY_VARIANTS = {
    "chr2:190986868:C/T": ("chr2:190986872:C/T",),
    "chr2:190986872:C/T": ("chr2:190986868:C/T",),
    "chr2:190997872:C/T": ("chr2:190997873:C/T",),
    "chr2:190997873:C/T": ("chr2:190997872:C/T",),
    "chr9:5054789:G/A": ("chr9:5054790:G/A",),
    "chr9:5054790:G/A": ("chr9:5054789:G/A",),
    "chr21:33421686:G/A": ("chr21:33421679:C/T",),
    "chr21:33421679:C/T": ("chr21:33421686:G/A", "chr21:33421675:C/T"),
    "chr1:64864885:G/A": ("chr1:64864886:G/A",),
    "chr1:64873414:T/C": ("chr1:64873415:T/C",),
    "chr1:64879107:C/T": ("chr1:64879105:C/T",),
    "chr2:190995142:G/A": ("chr2:190995144:G/A", "chr2:190995141:G/A"),
    "chr2:190999673:T/C": ("chr2:190999675:T/C", "chr2:190999669:T/C"),
    "chr5:132486824:G/A": ("chr5:132486825:G/A",),
    "chr6:137206214:A/G": ("chr6:137206215:A/G",),
    "chr6:137206249:A/G": ("chr6:137206248:A/G", "chr6:137206251:A/G"),
    "chr9:5078360:A/G": ("chr9:5078362:A/G",),
    "chr9:5089702:G/A": ("chr9:5089703:G/A",),
}


def call_single_loci_cells(
    target_loc, # the loci of interest
    adata_loom, # the loom matrix to retrieve info
    window_size, # the window size 
    germline_amplicons, # The germline mutations
    snp_name, # the name of the loci
    plot_path, # the path to save the figures
):
    """
    Classify the cells carrying one target variant as pure or bystander.

    Cells are grouped by their call at ``target_loc`` in ``adata_loom.X``
    (1 = heterozygous, 2 = homozygous). Each group is checked for other edits
    within ``window_size`` of the target. Germline variants and the per-target
    entries of ``_IGNORED_NEARBY_VARIANTS`` are excluded from that check.
    Cells with any such edit are bystander cells; the rest are pure. When a
    group has more than one bystander cell, diagnostic plots are written to
    ``plot_path``.

    Returns
    -------
    tuple of list
        ``(het_bystander, hom_bystander, het_pure, hom_pure)`` barcode lists.

    Raises
    ------
    ValueError
        If drawing the diagnostic plots fails (the original error is chained).
    """
    idx = adata_loom.var_names.get_loc(target_loc)
    target_calls = adata_loom.X[:, idx]
    het_cells = adata_loom.obs_names[np.flatnonzero(target_calls == 1)]
    hom_cells = adata_loom.obs_names[np.flatnonzero(target_calls == 2)]

    # Find all nearby variants of the given loci, minus the ones to ignore.
    nearby_variants = get_nearby_variants(
        adata_loom.var_names, target_loc, germline_amplicons, window_size
    )
    ignored = set(_IGNORED_NEARBY_VARIANTS.get(target_loc, ()))
    nearby_variants = [variant for variant in nearby_variants if variant not in ignored]

    het_pure, het_bystander, het_matrix = process_bystander_cells(
        target_loc, het_cells, nearby_variants, adata_loom
    )
    hom_pure, hom_bystander, hom_matrix = process_bystander_cells(
        target_loc, hom_cells, nearby_variants, adata_loom
    )

    # Only plot when there is more than one bystander cell to show.
    try:
        if len(het_bystander) > 1:
            plot_mutant_types(target_loc, het_matrix.fillna(0), "Heterozygous", snp_name, plot_path, "1-0")
        if len(hom_bystander) > 1:
            plot_mutant_types(target_loc, hom_matrix.fillna(0), "Homozygous", snp_name, plot_path, "1-0")
    except Exception as err:
        # Re-raise instead of opening a debugger, so batch runs fail fast
        # instead of hanging on an interactive prompt.
        raise ValueError("wrong figure plotting again") from err

    return list(het_bystander), list(hom_bystander), list(het_pure), list(hom_pure)


def annotate_cells(adata):
    """
    Assign one genotype label per cell from the SNP flag and AAV_control columns.

    SNP flag columns are the ``adata.obs`` columns whose names contain
    "het_bystander", "het_pure", "hom_bystander" or "hom_pure". They are
    combined with the "AAV_control" column as follows:

    * more than one SNP flag set                -> "mixed"
    * exactly one SNP flag set and AAV_control  -> "mixed"
    * exactly one SNP flag set                  -> that column's name
    * no SNP flag set and AAV_control           -> "AAV_control"
    * otherwise                                 -> "unedited"

    Parameters:
    - adata: AnnData object whose obs holds the boolean flag columns.

    Returns:
    - None. Labels are written to adata.obs['genotype_annotation'].
    """
    snp_columns = [col for col in adata.obs.columns if any(sub in col for sub in ["het_bystander", "het_pure", "hom_bystander", "hom_pure"])]
    control_column = "AAV_control"

    def annotate_row(row):
        snp_flags = row[snp_columns]
        n_active = snp_flags.sum()  # number of True SNP flags
        in_control = bool(row[control_column])

        if n_active > 1 or (n_active == 1 and in_control):
            return "mixed"
        if n_active == 1:
            return snp_flags[snp_flags == True].index[0]
        return "AAV_control" if in_control else "unedited"

    adata.obs['genotype_annotation'] = adata.obs.apply(annotate_row, axis=1)


def annotate_cells_v2(adata, bad_amplicons=None):
    """
    Assign one genotype label per cell, down-weighting unreliable ("bad") amplicons.

    SNP flag columns are the ``adata.obs`` columns whose names contain
    "het_pure", "hom_pure", "het_bystander" or "hom_bystander". A flag column
    is "bad" when it is ``f"{snp}_{suffix}"`` for a SNP in ``bad_amplicons``
    and one of those four suffixes. All other flag columns are "good".

    If at least one bad flag is set, the first matching rule wins:
      1. exactly one good flag set           -> that good column
      2. two or more good flags set          -> "mixed"
      3. AAV_control set                     -> "AAV_control"
      4. exactly one bad flag, no good flag  -> that bad column
      5. otherwise (several bad flags)       -> "mixed"
    If no bad flag is set:
      * two or more good flags set           -> "mixed"
      * exactly one good flag set            -> that column (AAV_control ignored)
      * AAV_control set                      -> "AAV_control"
      * otherwise                            -> "unedited"

    Parameters:
    - adata: AnnData object whose obs holds boolean flag columns and "AAV_control".
    - bad_amplicons: optional iterable of SNP names to treat as bad. Defaults
      to the built-in list below.

    Returns:
    - None. Labels are written to adata.obs['genotype_annotation'].
    """
    if bad_amplicons is None:
        bad_amplicons = [
            'STAT1-A402T',
            'IRF1-W11R',
            'IFNGR1-E164K',
            'IFNGR1-NC2',
            'JAK2-G281S',
            'JAK2-G281D',
        ]
    suffixes = ["het_pure", "hom_pure", "het_bystander", "hom_bystander"]
    bad_amplicon_genotype = {f"{snp}_{suffix}" for snp in bad_amplicons for suffix in suffixes}

    snp_columns = [col for col in adata.obs.columns if any(sub in col for sub in suffixes)]
    control_column = "AAV_control"

    bad_snp_columns = [col for col in snp_columns if col in bad_amplicon_genotype]
    good_snp_columns = [col for col in snp_columns if col not in bad_amplicon_genotype]

    def annotate_row(row):
        is_aav_active = row[control_column]
        active_bad_snps = row[bad_snp_columns][row[bad_snp_columns] == True].index.tolist()
        active_good_snps = row[good_snp_columns][row[good_snp_columns] == True].index.tolist()
        num_active_bad = len(active_bad_snps)
        num_active_good = len(active_good_snps)

        if num_active_bad > 0:
            # A single reliable (good) edit outranks the bad amplicon.
            if num_active_good == 1:
                return active_good_snps[0]
            # Several good edits conflict with each other. Check this before
            # AAV_control so the conflict is not hidden, matching the
            # no-bad-SNP branch below.
            if num_active_good > 1:
                return "mixed"
            if is_aav_active:
                return "AAV_control"
            if num_active_bad == 1:
                return active_bad_snps[0]
            return "mixed"

        if num_active_good > 1:
            return "mixed"
        if num_active_good == 1:
            return active_good_snps[0]
        return "AAV_control" if is_aav_active else "unedited"

    adata.obs['genotype_annotation'] = adata.obs.apply(annotate_row, axis=1)




def _resolve_target_loc(config, snp, adata_loom):
    """
    Return the variant name in ``adata_loom`` for one SNP config section.

    The name ``chrom_name:locus:variant_alleles`` is tried first. If it is not
    found, the name with both alleles base-complemented (e.g. ``C/T`` -> ``G/A``)
    is tried.

    Raises
    ------
    ValueError
        If the alleles cannot be complemented, or neither name is in
        ``adata_loom.var_names``.
    """
    chrom_name = config.get(snp, 'chrom_name')
    locus = config.getint(snp, 'locus')
    alleles = config.get(snp, 'variant_alleles')
    complement_rule = {'C': 'G', 'G': 'C', 'A': 'T', 'T': 'A'}
    try:
        complement_variant_alleles = complement_rule[alleles[0]] + "/" + complement_rule[alleles[-1]]
    except (KeyError, IndexError) as err:
        raise ValueError(f"The variant alleles is not supported. Current variant alleles is {alleles}") from err

    target_loc = ":".join([chrom_name, str(locus), alleles])
    complement_target_loc = ":".join([chrom_name, str(locus), complement_variant_alleles])
    if target_loc in adata_loom.var_names:
        return target_loc
    if complement_target_loc in adata_loom.var_names:
        return complement_target_loc
    raise ValueError(f"Both the mutation {target_loc} and the reversed one {complement_target_loc} is not found in the loom file")


def annotate_genotype(adata_loom,
                    adata_hypr,
                    config_path,
                    save_path="./",
                    plot_path="./",
                    window_size=10):
    """
    Add per-cell genotype flags and a summary label to ``adata_hypr.obs``.

    Columns written to ``adata_hypr.obs`` (in place):
    * "AAV_control": the cell carries an AAV control edit.
    * "AAV_del": the cell carries an AAV deletion edit. Only written when the
      config defines an [AAV_del] section.
    * "{snp}_het_bystander", "{snp}_het_pure", "{snp}_hom_bystander",
      "{snp}_hom_pure" for every other config section ``snp``.
    * "genotype_annotation": summary label from ``annotate_cells_v2``.

    Parameters
    ----------
    adata_loom : anndata.AnnData
        Genotype object (barcodes x variants).
    adata_hypr : anndata.AnnData
        Expression object. Barcodes are matched to ``adata_loom`` by name.
    config_path : str
        INI file with [AAV_control], an optional [AAV_del], and one section per
        SNP with ``chrom_name``, ``locus`` and ``variant_alleles``.
    save_path : str
        Currently unused; kept for backward compatibility.
    plot_path : str
        Directory for the bystander diagnostic plots.
    window_size : int, default 10
        Half-width of the window searched for bystander edits.

    Returns
    -------
    tuple
        ``(adata_hypr, adata_loom)``. ``adata_hypr`` is the same object, modified.
    """
    # Step 1: variants too frequent to be edits are treated as germline.
    germline_amplicons = call_germline_mutations(adata_loom, cutoff=0.25)

    # Step 2: flag cells carrying an AAV control edit.
    AAV_control_editing = call_AAV_control_edits(adata_loom, config_path=config_path)
    adata_hypr.obs["AAV_control"] = False
    if AAV_control_editing:
        AAV_control_cell_idx = call_AAV_cells(adata_hypr, adata_loom, AAV_editing=AAV_control_editing)
        adata_hypr.obs.iloc[AAV_control_cell_idx, adata_hypr.obs.columns.get_loc("AAV_control")] = True
    else:
        logger.info("No AAV control edits found; AAV_control is False for every cell")

    # Step 3: AAV deletion edits are optional; skip them when [AAV_del] (or one
    # of its options) is not configured. Any other error is a real problem and
    # propagates.
    try:
        AAV_del_editing = call_AAV_del_edits(adata_loom, config_path=config_path)
    except (configparser.NoSectionError, configparser.NoOptionError) as err:
        logger.info(f"Skipping AAV_del annotation: {err}")
    else:
        adata_hypr.obs["AAV_del"] = False
        if AAV_del_editing:
            AAV_del_cell_idx = call_AAV_cells(adata_hypr, adata_loom, AAV_editing=AAV_del_editing)
            adata_hypr.obs.iloc[AAV_del_cell_idx, adata_hypr.obs.columns.get_loc("AAV_del")] = True

    # Steps 4-5: every remaining config section describes one SNP of interest.
    config = configparser.ConfigParser()
    config.read(config_path)
    exclude_sections = {"AAV_control", "AAV_del"}
    snp_sections = [section for section in config.sections() if section not in exclude_sections]

    annotation = ["het_bystander", "het_pure", "hom_bystander", "hom_pure"]
    for snp in snp_sections:
        target_loc = _resolve_target_loc(config, snp, adata_loom)
        het_bystander, hom_bystander, het_pure, hom_pure = call_single_loci_cells(
            target_loc, adata_loom, window_size=window_size,
            germline_amplicons=germline_amplicons, snp_name=snp, plot_path=plot_path)

        for label, barcodes in zip(annotation, [het_bystander, het_pure, hom_bystander, hom_pure]):
            # isin() marks matching cells and ignores barcodes not in adata_hypr.
            adata_hypr.obs[f"{snp}_{label}"] = adata_hypr.obs_names.isin(barcodes)

    # Step 6: collapse the flag columns into one label per cell (mixed cells included).
    annotate_cells_v2(adata_hypr)

    return adata_hypr, adata_loom
    



def _load_and_intersect_sample(sample_name, meta, probe_level):
    """Read, filter and intersect one Tapestri (loom) / Hypr-seq sample pair.

    Args:
        sample_name: Key of the sample in the ``datasets`` table; written to
            ``.obs['sample_name']`` of both returned objects.
        meta: Dict with ``loom_path``, ``hypr_path``, ``condition`` and ``batch``.
        probe_level: Forwarded to ``filter_hypr``.

    Returns:
        ``(adata_hypr, adata_loom)`` restricted to the intersecting cells, with
        ``condition``, ``batch`` and ``sample_name`` columns added to ``.obs``;
        or ``None`` (after printing the reason) when an input file is missing
        or no cells survive the intersection.
    """
    loom_path = Path(meta['loom_path'])
    hypr_path = Path(meta['hypr_path'])

    if not loom_path.is_file() or not hypr_path.is_file():
        print(f"  Warning: Input file(s) not found for {sample_name}. Skipping.")
        return None

    adata_loom_orig = read_tapestri_loom(loom_path)
    adata_hypr_orig = sc.read(hypr_path)
    print(f"  Read loom: {adata_loom_orig.shape}")
    print(f"  Read hypr: {adata_hypr_orig.shape}")

    adata_hypr_filtered = filter_hypr(adata_hypr_orig, probe_level=probe_level)
    print(f"  Filtered hypr: {adata_hypr_filtered.shape}")

    adata_hypr_intersect, adata_loom_intersect = find_intersecting_and_filter(adata_hypr_filtered, adata_loom_orig)

    # Optional: further intersect the barcodes with the anders' version
    # anders_hypr = sc.read("/mnt/data/project/25_04_29_Figure3_reanalysis/processed_data/merged_anders/Annotated_phenotype_hypr_seq_MERGED.h5ad")
    # obs_names_of_interest = [obs.split("-")[0] for obs in anders_hypr.obs_names]
    # adata_hypr_intersect = adata_hypr_intersect[adata_hypr_intersect.obs_names.isin(obs_names_of_interest), :].copy()
    # adata_loom_intersect = adata_loom_intersect[adata_hypr_intersect.obs_names, :].copy()

    if adata_hypr_intersect.shape[0] == 0 or adata_loom_intersect.shape[0] == 0:
        print(f"  Skipping {sample_name} due to zero intersecting cells after filtering.")
        return None

    print(f"  Intersected hypr: {adata_hypr_intersect.shape}")
    print(f"  Intersected loom: {adata_loom_intersect.shape}")

    # Tag both modalities BEFORE concatenation so every merged cell keeps its origin.
    for adata in (adata_hypr_intersect, adata_loom_intersect):
        adata.obs['condition'] = meta['condition']
        adata.obs['batch'] = meta['batch']
        adata.obs['sample_name'] = sample_name  # Keep original sample name

    return adata_hypr_intersect, adata_loom_intersect


def _concat_samples(adatas):
    """Concatenate per-sample AnnData objects along the cell axis.

    ``join='outer'`` keeps every feature seen in any sample. No ``keys`` are
    given, so anndata labels each input by its position in ``adatas``; that
    label is stored in ``.obs['sample_key']`` and appended to each barcode
    (``index_unique='-'``). The hypr and loom lists must therefore be built in
    the same order for their merged barcodes to match.
    """
    return ad.concat(
        adatas,
        join='outer',
        label='sample_key',
        index_unique='-',
        merge='unique',
    )


def main():
    """Command-line entry point: filter, intersect, merge and annotate all samples.

    Arguments (all required): ``--config_path`` (INI variant configuration passed
    to ``annotate_genotype``), ``--save_path`` (output directory for the annotated
    ``.h5ad`` files), ``--plot_path`` (output directory for figures) and
    ``--probe_level`` (``1`` keeps the probe-level Hypr-seq matrix, any other
    integer aggregates probes to genes in ``filter_hypr``). Both output
    directories are created if missing.

    A sample that fails to load is logged with its traceback and skipped. A
    failure while merging, annotating or saving is logged with its traceback and
    ends the run.
    """
    parser = argparse.ArgumentParser(description="Filter, Intersect, Merge, and Annotate genotype and phenotype data from multiple experiments.")

    # --- Define Input Datasets ---
    # Paths and metadata (condition / batch) for each sample.
    # MODIFY THIS SECTION with your actual file paths and desired batch names.
    datasets = {
        'ifng_rep1': {
            'loom_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/tapestri_loom/v3/MDM_DSIFNG_IFNg_gDNA1_v3.cells.loom',
            'hypr_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/hypr_probe_matrix/MDM_DSIFNG_IFNg_HyPR1.txt',  # Or other hypr format
            'condition': 'IFNG',
            'batch': 'rep1',
        },
        'ifng_rep2': {
            'loom_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/tapestri_loom/v3/MDM_DSIFNG_IFNg_gDNA2_v3.cells.loom',
            'hypr_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/hypr_probe_matrix/MDM_DSIFNG_IFNg_HyPR2.txt',
            'condition': 'IFNG',
            'batch': 'rep2',
        },
        'ifng_rep3': {
            'loom_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/tapestri_loom/v3/MDM_DSIFNG_IFNg_gDNA3_v3.cells.loom',
            'hypr_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/hypr_probe_matrix/MDM_DSIFNG_IFNg_HyPR3.txt',
            'condition': 'IFNG',
            'batch': 'rep3',
        },
        'control_rep1': {
            'loom_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/tapestri_loom/v3/MDM_DSIFNG_Ctrl_gDNA1_v3.cells.loom',
            'hypr_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/hypr_probe_matrix/MDM_DSIFNG_Ctrl_HyPR1.txt',
            'condition': 'PBS',
            'batch': 'rep4',  # batch labels are unique across all samples (rep1-rep5)
        },
        'control_rep2': {
            'loom_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/tapestri_loom/v3/MDM_DSIFNG_Ctrl_gDNA2_v3.cells.loom',
            'hypr_path': '/mnt/data/project/25_04_29_Figure3_reanalysis/data_raw/hypr_probe_matrix/MDM_DSIFNG_Ctrl_HyPR2.txt',
            'condition': 'PBS',
            'batch': 'rep5',
        },
    }

    # Add arguments for common paths and settings
    parser.add_argument('--config_path', type=str, required=True, help='The path to the variant configuration file (INI format)')
    parser.add_argument('--save_path', type=str, required=True, help='The base path to save annotated anndata files')
    parser.add_argument('--plot_path', type=str, required=True, help='The base path to save genotype clustermap and other figures')
    parser.add_argument('--probe_level', type=int, required=True, help="1 for probe level matrix, other numbers for gene_level")

    args = parser.parse_args()

    config_path = args.config_path
    save_path = Path(args.save_path)
    plot_path = Path(args.plot_path)
    probe_level = (args.probe_level == 1)

    save_path.mkdir(parents=True, exist_ok=True)
    plot_path.mkdir(parents=True, exist_ok=True)

    # --- Process and Filter Individual Datasets ---
    processed_hypr_list = []
    processed_loom_list = []

    print("--- Processing Individual Datasets ---")
    for sample_name, meta in datasets.items():
        print(f"\nProcessing sample: {sample_name}")
        try:
            result = _load_and_intersect_sample(sample_name, meta, probe_level)
        except Exception as e:
            # Skip this sample but keep the traceback for debugging.
            logger.exception("  Error processing sample %s: %s", sample_name, e)
            continue
        if result is None:
            continue

        adata_hypr_intersect, adata_loom_intersect = result
        # Both lists grow together, so list positions (sample_key) match across modalities.
        processed_hypr_list.append(adata_hypr_intersect)
        processed_loom_list.append(adata_loom_intersect)
        print(f"  Successfully processed and filtered {sample_name}.")

    # --- Merge Datasets ---
    if not processed_hypr_list or not processed_loom_list:
        print("\nError: No datasets successfully processed. Exiting.")
        return

    print("\n--- Merging Processed Datasets ---")
    try:
        adata_hypr_merged = _concat_samples(processed_hypr_list)
        print(f"Merged hypr data shape: {adata_hypr_merged.shape}")
        print(f"Merged hypr obs columns: {adata_hypr_merged.obs.columns.tolist()}")

        adata_loom_merged = _concat_samples(processed_loom_list)
        # Put the loom cells in exactly the same order as the merged hypr cells.
        adata_loom_merged = adata_loom_merged[adata_hypr_merged.obs_names, :].copy()
        print(f"Merged loom data shape: {adata_loom_merged.shape}")
        print(f"Merged loom obs columns: {adata_loom_merged.obs.columns.tolist()}")

        if not adata_hypr_merged.obs_names.equals(adata_loom_merged.obs_names):
            raise ValueError("Merged loom and hypr AnnData objects are not perfectly aligned by cell barcode.")
    except Exception as e:
        logger.exception("Error during AnnData concatenation: %s", e)
        return

    # --- Annotate Merged Data ---
    print("\n--- Annotating Merged Data ---")
    try:
        adata_hypr_annotated, adata_loom_annotated = annotate_genotype(
            adata_loom=adata_loom_merged,
            adata_hypr=adata_hypr_merged,
            config_path=config_path,
            save_path=str(save_path),
            plot_path=str(plot_path),
        )
    except Exception as e:
        logger.exception("Error during annotation: %s", e)
        return

    # --- Save Final Annotated Data ---
    print("\n--- Saving Final Annotated Data ---")
    if probe_level:
        hypr_name = save_path / "Annotated_phenotype_hypr_seq_probe_MERGED.h5ad"
    else:
        hypr_name = save_path / "Annotated_phenotype_hypr_seq_MERGED.h5ad"
    loom_name = save_path / "Annotated_genotype_loom_MERGED.h5ad"
    try:
        print(f"Saving annotated hypr data to: {hypr_name}")
        sc.write(hypr_name, adata_hypr_annotated)

        # Ensure sparse (CSR) matrix format before writing the genotype matrix.
        adata_loom_annotated.X = sp.csr_matrix(adata_loom_annotated.X)
        print(f"Saving annotated loom data to: {loom_name}")
        sc.write(loom_name, adata_loom_annotated)

        print("Successfully saved final annotated files.")
    except Exception as e:
        logger.exception("Error saving final files: %s", e)


if __name__ == "__main__":
    # There is no default config path: --config_path, --save_path, --plot_path
    # and --probe_level are all required command-line arguments (see main()).
    main()

In [None]:
# Next, assign donors
import scanpy as sc
import anndata as ad
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from scipy.sparse import issparse, csr_matrix
import matplotlib.pyplot as plt

# --- Configuration ---
# cluster_donors_from_genotypes() below takes an in-memory AnnData
# (cells x variants); the __main__ block shows how the merged file is loaded.

# Germline-marker selection: a variant is kept when it is called non-zero
# (Het or Hom) in at least this fraction of cells.
min_cell_frequency_for_germline = 0.25

# Clustering: K-Means k (number of donors), the maximum number of principal
# components handed to K-Means, and the K-Means seed for reproducibility.
n_donors_to_cluster, n_pcs_for_clustering, kmeans_random_state = 2, 10, 42

# .obs column that receives the inferred donor label
donor_cluster_col = 'inferred_donor_cluster'

def _clean_genotype_matrix(X):
    """Return ``X`` with NaN and 3 (Unknown) entries set to 0.

    Sparse input is converted to CSR and explicit zeros are removed afterwards,
    so ``getnnz`` then counts only non-zero calls. Dense input goes through
    ``np.nan_to_num`` (which also maps +/-inf to large finite values).
    """
    if issparse(X):
        print("  Matrix is sparse. Converting to CSR for modification.")
        X = X.tocsr()

        # NaNs can only be present in the stored data array of a sparse matrix.
        if hasattr(X, 'data') and np.isnan(X.data).any():
            print("  Filling NaNs in sparse matrix data array with 0...")
            X.data[np.isnan(X.data)] = 0
            X.eliminate_zeros()  # Clean up explicit zeros

        print("  Changing 3s (Unknown) to 0s (WT) in sparse matrix...")
        X.data[X.data == 3] = 0
        X.eliminate_zeros()  # Clean up after modification
    else:
        print("  Matrix is dense.")
        print("  Filling NaNs with 0...")
        X = np.nan_to_num(X, nan=0.0)

        print("  Changing 3s (Unknown) to 0s (WT)...")
        X[X == 3] = 0
    return X


def _nonzero_fraction_per_variant(X, n_obs):
    """Fraction of cells (rows) with a non-zero entry, per variant (column)."""
    if issparse(X):
        return X.getnnz(axis=0) / n_obs
    return np.mean(X != 0, axis=0)


def cluster_donors_from_genotypes(adata_dna_input, umap_plot_path="plots/umap_frequent_variants_sparse.png"):
    """
    Infers donor clusters from a genotype AnnData object.

    Variants that are non-zero in at least ``min_cell_frequency_for_germline``
    of the cells are kept, reduced with PCA (at most ``n_pcs_for_clustering``
    components) and clustered with K-Means into ``n_donors_to_cluster`` groups.
    These module-level settings are read at call time.

    Args:
        adata_dna_input (sc.AnnData): AnnData object with .X as genotype matrix (cells x variants).
            Values are expected to be 0 (WT), 1 (Het), 2 (Hom), 3 (Unknown);
            NaN and 3 are treated as 0 for clustering. Its .X is not modified.
        umap_plot_path (str | pathlib.Path | None): Where to save a diagnostic UMAP of
            the frequent variants coloured by cluster; parent directories are
            created. ``None`` skips the neighbours/UMAP computation and the plot.

    Returns:
        sc.AnnData: ``adata_dna_input`` itself, with a new categorical column
        ``.obs[donor_cluster_col]`` holding K-Means cluster ids as strings
        ("0", "1", ...). The ids are arbitrary labels, not fixed donor identities.
        Returns None if .X is empty, no variant passes the frequency filter,
        or PCA fails.
    """
    print("--- Starting Donor Clustering by Frequent Variants ---")

    # Clean a copy so the caller's .X stays untouched; only the new .obs column
    # is written back to adata_dna_input (Step 5).
    adata_dna = adata_dna_input.copy()

    # --- 1. Preprocess Genotype Matrix (adata_dna.X) ---
    print("\nStep 1: Preprocessing genotype matrix...")
    if adata_dna.X is None or adata_dna.shape[0] == 0 or adata_dna.shape[1] == 0:
        print("Error: adata_dna.X is empty or has zero dimensions.")
        return None

    adata_dna.X = _clean_genotype_matrix(adata_dna.X)
    print("  Preprocessing of .X complete.")

    # --- 2. Identify "Germline-like" Variants (Frequent Variants) ---
    print("\nStep 2: Identifying frequent variants (potential germline markers)...")
    variant_frequency = _nonzero_fraction_per_variant(adata_dna.X, adata_dna.n_obs)
    frequent_variant_mask = variant_frequency >= min_cell_frequency_for_germline
    frequent_variant_names = adata_dna.var_names[frequent_variant_mask].tolist()

    if not frequent_variant_names:
        print(f"Error: No variants found appearing in >= {min_cell_frequency_for_germline*100}% of cells. Cannot proceed.")
        return None

    print(f"  Found {len(frequent_variant_names)} frequent variants (out of {adata_dna.n_vars}).")

    # --- 3. Subset AnnData to Frequent Variants ---
    print("\nStep 3: Subsetting AnnData to frequent variants...")
    adata_frequent_vars = adata_dna[:, frequent_variant_names].copy()
    print(f"  Subsetted data shape: {adata_frequent_vars.shape}")

    # --- 4. Clustering ---
    print("\nStep 4: Performing PCA and K-Means clustering...")

    # ARPACK needs n_comps strictly smaller than both matrix dimensions.
    actual_n_pcs = min(n_pcs_for_clustering, adata_frequent_vars.n_obs - 1, adata_frequent_vars.n_vars - 1)
    if actual_n_pcs < 2:
        print(f"Warning: Number of PCs for clustering is {actual_n_pcs}. Clustering might be suboptimal.")
        if actual_n_pcs <= 0:
            print("Error: Cannot perform PCA with <=0 components. Check data dimensions and n_pcs_for_clustering.")
            return None

    print(f"  Running PCA with {actual_n_pcs} components...")
    try:
        sc.tl.pca(adata_frequent_vars, n_comps=actual_n_pcs, svd_solver='arpack', zero_center=True)
    except Exception as e_pca:
        print(f"  Error during PCA: {e_pca}. Attempting with zero_center=None for sparse if applicable.")
        try:
            sc.tl.pca(adata_frequent_vars, n_comps=actual_n_pcs, svd_solver='arpack', zero_center=None if issparse(adata_frequent_vars.X) else True)
        except Exception as e_pca2:
            print(f"  PCA failed again: {e_pca2}. Cannot proceed with clustering.")
            return None

    print(f"  Running K-Means clustering (k={n_donors_to_cluster}) on PCA results...")
    pca_coordinates = adata_frequent_vars.obsm['X_pca']
    kmeans = KMeans(n_clusters=n_donors_to_cluster, random_state=kmeans_random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(pca_coordinates)

    adata_frequent_vars.obs[donor_cluster_col] = pd.Categorical(cluster_labels.astype(str))
    print("  Clustering complete.")
    print(f"  Cluster distribution:\n{adata_frequent_vars.obs[donor_cluster_col].value_counts()}")

    # Diagnostic UMAP only; it has no effect on the cluster labels.
    if umap_plot_path is not None:
        from pathlib import Path

        sc.pp.neighbors(adata_frequent_vars, n_pcs=min(10, adata_frequent_vars.obsm['X_pca'].shape[1]), use_rep='X_pca')
        sc.tl.umap(adata_frequent_vars)
        # show=False keeps the figure open for savefig. If scanpy shows the plot
        # itself (its default), inline backends may close the figure first, and
        # an empty image is written.
        sc.pl.umap(adata_frequent_vars, color=donor_cluster_col,
                   title="UMAP of Frequent Variants (Sparse) by Inferred Donor", show=False)
        plot_file = Path(umap_plot_path)
        plot_file.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(plot_file, dpi=300)
        plt.show()

    # --- 5. Add Cluster Labels to Original AnnData ---
    print(f"\nStep 5: Adding '{donor_cluster_col}' to the input AnnData object's .obs...")
    adata_dna_input.obs[donor_cluster_col] = adata_frequent_vars.obs[donor_cluster_col].reindex(adata_dna_input.obs_names)

    print("--- Donor Clustering Complete ---")
    return adata_dna_input


# --- Run donor clustering on the merged genotype data ---
if __name__ == '__main__':
    merged_dir = "/mnt/data/project/25_04_29_Figure3_reanalysis/processed_data/merged_v9"

    adata_loom = sc.read(f"{merged_dir}/Annotated_genotype_loom_MERGED.h5ad")
    adata_loom_clustered = cluster_donors_from_genotypes(adata_loom)

    # cluster_donors_from_genotypes returns None on failure; check before any use.
    if adata_loom_clustered is None:
        print("\nDonor clustering failed; the phenotype file was not updated.")
    else:
        donor_labels = adata_loom_clustered.obs[donor_cluster_col]

        adata_rna = sc.read(f"{merged_dir}/Annotated_phenotype_hypr_seq_MERGED.h5ad")
        # Match labels by cell barcode; RNA cells absent from the genotype data get NaN.
        adata_rna.obs[donor_cluster_col] = donor_labels.reindex(adata_rna.obs_names)
        adata_rna.write(f"{merged_dir}/Annotated_phenotype_hypr_seq_MERGED_with_donor_clusters.h5ad")

        print("\nClustering results (head of .obs):")
        print(adata_loom_clustered.obs.head())
        print("\nCluster counts:")
        print(donor_labels.value_counts())
    

In [None]:
# Then, compute DEGs

import scanpy as sc
import pandas as pd
import numpy as np
from pathlib import Path
import re

# --- Configuration ---
adata_file = '/mnt/data/project/25_04_29_Figure3_reanalysis/processed_data/merged_v9/Annotated_phenotype_hypr_seq_MERGED_with_donor_clusters.h5ad'
# geno = sc.read("/mnt/data/project/25_04_29_Figure3_reanalysis/data_anders/tapestri_adata.h5ad")
genotype_col = 'genotype_annotation'
condition_col = 'condition'
wt_genotype_label = 'AAV_control'  # Reference genotype every other genotype is compared against

# DE Analysis Parameters
de_method = 'wilcoxon'
de_corr_method = 'benjamini-hochberg'
# Every gene is written out, so no logFC / p-value thresholds are applied.

# Output directory and file
output_dir = Path(f'./analysis_results_v9/deg_vs_{wt_genotype_label}_wide_table')
output_dir.mkdir(parents=True, exist_ok=True)
output_filename = output_dir / f"all_pairwise_deg_stats_vs_{wt_genotype_label}_wide.csv"

# --- 1. Load Data and Prepare ---
print(f"Loading full annotated data from: {adata_file}")
try:
    if isinstance(adata_file, (str, Path)):
        adata_full = sc.read_h5ad(adata_file)
        # Alternatives used in earlier runs:
        # adata_full.obs['inferred_donor_cluster'] = geno.obs['genotype_cluster']
        # sc.pp.filter_cells(adata_full, min_counts=500)  # Drop cells with < 500 total counts
    elif isinstance(adata_file, sc.AnnData):
        adata_full = adata_file  # Use the object directly if passed
    else:
        raise TypeError("adata_file must be a path string or an AnnData object.")
    print("Full data loaded successfully.")

    required_obs_cols = [genotype_col, condition_col]
    for col in required_obs_cols:
        if col not in adata_full.obs.columns:
            raise ValueError(f"Required column '{col}' not found in adata_full.obs.")

    if adata_full.X is None:
        raise ValueError("Suitable expression data not found in .X for processing.")

    # DE (and scanpy's logFC) expects log-normalized expression. sc.pp.log1p
    # records itself in .uns['log1p'], so its presence means .X is already
    # log-transformed and must not be normalized/logged a second time,
    # whether .X is dense or sparse.
    adata_processed = adata_full.copy()
    if 'log1p' in adata_full.uns:
        print("Using expression from adata_full.X (assumed log-normalized).")
    else:
        print("Using adata_full.X for DE processing (will normalize and log1p if needed)...")
        sc.pp.normalize_total(adata_processed, target_sum=1e4)
        sc.pp.log1p(adata_processed)
        print("  Normalization and log1p applied to .X data if needed.")

    adata_processed.obs[genotype_col] = adata_processed.obs[genotype_col].astype('category')
    adata_processed.obs[condition_col] = adata_processed.obs[condition_col].astype('category')

# Abort with a non-zero status: later sections need adata_processed.
except FileNotFoundError:
    print(f"Error: AnnData file not found at {adata_file}.")
    raise SystemExit(1)
except (ValueError, TypeError) as ve:
    print(f"Error: {ve}")
    raise SystemExit(1)
except Exception as e:
    print(f"An error occurred loading/preparing the AnnData file: {e}")
    raise SystemExit(1)

# --- 2. Perform Pairwise DE Analysis and Collect Results for Merging ---
# DE is computed within a single inferred donor, which keeps donor-to-donor
# differences out of the genotype comparison. Donor ids come from K-Means and
# are arbitrary labels: confirm which donor "0" corresponds to.
donor_col = 'inferred_donor_cluster'
donor_of_interest = "0"
if donor_col not in adata_processed.obs.columns:
    print(f"Error: Required column '{donor_col}' not found in adata_processed.obs.")
    raise SystemExit(1)
adata_processed = adata_processed[adata_processed.obs[donor_col] == donor_of_interest].copy()
print(f"\nRestricted to {donor_col} == '{donor_of_interest}': {adata_processed.n_obs} cells")

list_of_de_dataframes_for_merge = []

unique_conditions = adata_processed.obs[condition_col].cat.categories.tolist()
unique_genotypes = adata_processed.obs[genotype_col].cat.categories.tolist()
print(f"\nFound conditions: {unique_conditions}")
print(f"Found genotypes: {unique_genotypes}")

for current_cond in unique_conditions:
    print(f"\n--- Processing Condition: {current_cond} ---")
    adata_cond = adata_processed[adata_processed.obs[condition_col] == current_cond].copy()

    # Test for actual WT cells, not merely a (possibly unused) category.
    if not (adata_cond.obs[genotype_col] == wt_genotype_label).any():
        print(f"  Warning: WT genotype '{wt_genotype_label}' not found in condition '{current_cond}'. Skipping DE for this condition.")
        continue

    for current_geno in unique_genotypes:
        if current_geno == wt_genotype_label:
            continue

        print(f"  Comparing: '{current_geno}' vs '{wt_genotype_label}'")
        groups_to_compare = [current_geno, wt_genotype_label]
        adata_comparison = adata_cond[adata_cond.obs[genotype_col].isin(groups_to_compare)].copy()

        # Require at least 3 cells in each group for a meaningful test.
        group_counts = adata_comparison.obs[genotype_col].value_counts()
        if len(group_counts) < 2 or group_counts.get(current_geno, 0) < 3 or group_counts.get(wt_genotype_label, 0) < 3:
            print(f"    Skipping: Insufficient cells or groups for DE. Counts: {group_counts.to_dict()}")
            continue

        adata_comparison.obs[genotype_col] = adata_comparison.obs[genotype_col].astype('category').cat.set_categories(groups_to_compare)

        try:
            sc.tl.rank_genes_groups(
                adata_comparison,
                groupby=genotype_col,
                groups=[current_geno],
                reference=wt_genotype_label,
                method=de_method,
                use_raw=False,
                corr_method=de_corr_method,
                n_genes=adata_comparison.n_vars,  # report every gene
            )

            degs_df_full = sc.get.rank_genes_groups_df(adata_comparison, group=current_geno)

            # Index by gene name and keep only logFC and adjusted p-value for the wide table.
            df_for_merge = degs_df_full[['names', 'logfoldchanges', 'pvals_adj']].copy()
            df_for_merge = df_for_merge.set_index('names')

            # Sanitize current_geno for column names (replace special characters)
            sanitized_geno_name = re.sub(r'[^A-Za-z0-9_]+', '_', current_geno)

            logfc_col_name = f"{current_cond}_{sanitized_geno_name}_vs_{wt_genotype_label}_logFC"
            pval_adj_col_name = f"{current_cond}_{sanitized_geno_name}_vs_{wt_genotype_label}_pval_adj"

            df_for_merge = df_for_merge.rename(columns={
                'logfoldchanges': logfc_col_name,
                'pvals_adj': pval_adj_col_name,
            })

            list_of_de_dataframes_for_merge.append(df_for_merge)

            print(f"    Processed DE for {len(df_for_merge)} genes. Stats will be added to the wide table.")

        except Exception as e:
            print(f"    Error during DE analysis for '{current_geno}' vs '{wt_genotype_label}' in condition '{current_cond}': {e}")

# --- 3. Merge all DE results into a single wide DataFrame ---
print("\n--- Merging all DE results into a single wide table ---")
if not list_of_de_dataframes_for_merge:
    print("No DE results were generated to merge and save.")
else:
    # One (logFC, adjusted p-value) column pair per comparison, aligned on the
    # gene-name index; the outer join keeps genes present in only some
    # comparisons. Rows are ordered by gene name.
    final_wide_df = pd.concat(list_of_de_dataframes_for_merge, axis=1, join='outer').sort_index()

    print(f"Final wide DataFrame shape: {final_wide_df.shape}")
    print("Head of the final wide DataFrame:")
    print(final_wide_df.head())

    try:
        final_wide_df.to_csv(output_filename)
    except Exception as e:
        print(f"  Error saving wide DEG table: {e}")
    else:
        print(f"  Successfully saved wide DEG table to: {output_filename}")

print("\n--- Pairwise DE Analysis (Single Wide Table) Complete ---")

# Save the number of cells per (condition, genotype) to the output folder.
# NOTE: adata_processed was restricted to a single inferred donor before the DE
# loop, so these counts cover that donor only.
# observed=False is passed explicitly: it is the current pandas default for
# categorical groupers, but that default is deprecated. With it, every
# category combination gets a row/column, even when its count is zero.
genotype_counts = (
    adata_processed.obs
    .groupby([condition_col, genotype_col], observed=False)
    .size()
    .unstack(fill_value=0)
)
genotype_counts_filename = output_dir / "genotype_counts_per_condition.csv"
try:
    genotype_counts.to_csv(genotype_counts_filename)
    print(f"  Successfully saved genotype counts per condition to: {genotype_counts_filename}")
except Exception as e:
    print(f"  Error saving genotype counts per condition: {e}")
# --- End of Script ---