In [None]:
# This script takes the Hypr-seq AnnData and the Tapestri loom/h5 file as input
# and outputs scRNA-seq data with detailed genotype annotations

import scanpy as sc
import numpy as np
import pandas as pd

import argparse
import seaborn as sns
import sys
import logging
import configparser
import anndata
import matplotlib.pyplot as plt
import re
import matplotlib.patches as mpatches
import os
import argparse
import copy
import h5py

# Configure logging: INFO-and-above records go to stdout as bare messages
# (the format string contains only the message, no timestamp or level).
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(message)s')
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)

# Read and filter the Hypr-seq anndata
def filter_hypr(adata, probe_level=False, min_counts=1000):
    """
    Drop low-count cells and collapse the probe matrix to one column per gene.

    Parameters
    ----------
    adata : anndata.AnnData
        Cell-by-probe count matrix. It is modified in place: cells below
        ``min_counts`` are removed, ``adata.var['gene_name']`` is added and
        ``adata.X`` is replaced by a dense array.
    probe_level : bool
        If True, every probe keeps its own name, so nothing is merged.
        If False, the text before the first underscore of each probe name is
        used as the gene name and probes sharing that name are summed.
    min_counts : int
        Minimum total counts a cell needs to be kept; forwarded to
        ``sc.pp.filter_cells``. Defaults to 1000.

    Returns
    -------
    anndata.AnnData
        A new object holding the aggregated counts (cells as rows, gene names
        as columns, in the order produced by the pandas groupby) and a copy
        of ``adata.obs``. ``.var`` annotations and ``.uns`` are not carried over.
    """
    # Remove cells with fewer than `min_counts` UMIs (in place).
    sc.pp.filter_cells(adata, min_counts=min_counts)

    # Derive the grouping key for every probe.
    if probe_level:
        gene_names = list(adata.var_names)
    else:
        # Probe names are split at the first underscore: "<gene>_<probe id>".
        gene_names = [name.split('_')[0] for name in adata.var_names]
    adata.var['gene_name'] = gene_names

    # Densify only when needed: a matrix that is already a dense array has
    # no .toarray() and used to raise AttributeError here.
    if hasattr(adata.X, "toarray"):
        adata.X = adata.X.toarray()
    else:
        adata.X = np.asarray(adata.X)

    # Probes as rows so that groupby can sum the probes of each gene.
    adata_df = pd.DataFrame(adata.X.T, index=adata.var_names, columns=adata.obs_names)
    aggregated_data = adata_df.groupby(adata.var['gene_name']).sum()

    # Back to cells as rows, genes as columns.
    aggregated_data = aggregated_data.T

    adata_aggregated = sc.AnnData(X=aggregated_data)
    adata_aggregated.obs = adata.obs.copy()

    return adata_aggregated




# function reads the loomfile downloaded from Tapestri portal
def read_tapestri_H5(filename):
    """
    Read the DNA-variant assay of a MissionBio Tapestri HDF5 file.

    Parameters
    ----------
    filename : str
        Path to the file; it is opened read-only with h5py.

    Returns
    -------
    anndata.AnnData
        adata.X: the ``NGT`` layer stored as int8 (rows are barcodes,
            columns are variants)
        adata.obs_names: cell barcodes
        adata.var_names: variant ids
        adata.varm['amplicon'], adata.varm['chrom'], adata.varm['loc']:
            per-variant amplicon name, chromosome and position

        No additional layers are created.
    """

    def _read_text(dataset):
        # These datasets come back from h5py as bytes; decode them to str.
        return np.array([entry.decode("utf-8") for entry in dataset[:]])

    # Everything is read into memory inside the `with`, which also closes the
    # file; no explicit close() is needed afterwards.
    with h5py.File(filename, "r") as file:
        dna_variants = file['assays']['dna_variants']
        column_attrs = dna_variants['ca']

        variant_names = _read_text(column_attrs['id'])
        amplicon_names = _read_text(column_attrs['amplicon'])
        chromosome = _read_text(column_attrs['CHROM'])
        location = column_attrs['POS'][:]

        barcodes = _read_text(dna_variants['ra']['barcode'])

        mutation_matrix = dna_variants['layers']['NGT'][:, :]

    adata = anndata.AnnData(X=mutation_matrix, dtype=np.int8)
    adata.obs_names = barcodes
    adata.var_names = variant_names
    adata.varm["amplicon"] = amplicon_names
    adata.varm["chrom"] = chromosome
    adata.varm["loc"] = location

    return adata



# Optional function, visualize the barcode distribution
def barcode_rank_plot(adata, minimum=0, xmax=None):
    """
    Show a log-log barcode rank plot of the total counts per cell.

    Parameters
    ----------
    adata : anndata.AnnData
        Object whose ``X`` holds the counts (dense or sparse).
    minimum : number
        Only cells whose total count is strictly greater than this are drawn.
    xmax : number or None
        Optional upper limit of the x axis (barcode rank).
    """
    # Summing a sparse matrix yields an (n, 1) matrix; flatten to 1-D before
    # filtering so the boolean mask selects cells, not matrix rows by index.
    cell_umi_counts = np.asarray(adata.X.sum(axis=1)).ravel()
    cell_umi_counts = cell_umi_counts[cell_umi_counts > minimum]

    # Descending order: rank 1 is the cell with the most counts.
    sorted_umi_counts = np.sort(cell_umi_counts)[::-1]

    plt.figure(figsize=(5, 4), dpi=150)
    sns.lineplot(x=range(1, len(sorted_umi_counts) + 1), y=sorted_umi_counts)
    plt.xlabel('Barcode Rank')
    plt.ylabel('UMI Count')
    plt.title('Cell Barcode Rank Plot')
    plt.yscale('log')  # Log scale for better visualization
    plt.xscale('log')  # Log scale
    if xmax is not None:
        # Only the right limit is set: 0 is not a valid bound on a log axis.
        plt.xlim(right=xmax)
    plt.show()


# Function for finding the intersecting barcodes between the modalities
# We can drop the idea of mudata for simplicity
def find_intersecting_and_filter(adata_hypr, adata_h5):
    """
    Restrict both datasets to their shared barcodes, in the same order.

    Parameters
    ----------
    adata_hypr : anndata.AnnData
        Hypr-seq (RNA) data, barcodes in ``obs_names``.
    adata_h5 : anndata.AnnData
        Tapestri (genotype) data, barcodes in ``obs_names``.

    Returns
    -------
    (anndata.AnnData, anndata.AnnData)
        Subsets of ``adata_hypr`` and ``adata_h5`` whose rows are the common
        barcodes in sorted order. Independent copies are returned rather than
        views, because later steps write into ``.obs``.
    """
    obs_1, obs_2 = adata_hypr.obs_names, adata_h5.obs_names
    # intersect1d returns the sorted common barcodes plus, for each of them,
    # its position in obs_1 and in obs_2, so both subsets line up row by row.
    _, idx_1, idx_2 = np.intersect1d(obs_1, obs_2, return_indices=True)

    logger.info("Found %d barcodes in modality hypr seq", len(obs_1))
    logger.info("Found %d barcodes in modality loom file", len(obs_2))
    logger.info("Found %d intersecting barcodes", len(idx_1))

    adata_hypr_new = adata_hypr[idx_1, :].copy()
    adata_h5_new = adata_h5[idx_2, :].copy()

    return adata_hypr_new, adata_h5_new
    

# Step 1, determine the germline mutations from the loom file
# We assume that a mutation that appears in too many cells of the dataset is germline
# (e.g., if more than 25% of cells carry the homozygous mutation, we call it homozygous germline)
def call_germline_mutations(adata_h5, cutoff=0.25):
    """
    Call variants that are frequent enough to be treated as germline.

    A variant is reported when the fraction of cells with genotype value 1
    (heterozygous), or the fraction with genotype value 2 (homozygous), is
    strictly greater than ``cutoff``.

    Parameters
    ----------
    adata_h5 : anndata.AnnData
        Genotype matrix (cells x variants) with variant names in ``var_names``.
    cutoff : float
        Fraction of cells above which a variant is called germline.

    Returns
    -------
    list
        Variant names, heterozygous calls first, then homozygous calls that
        were not already listed. Each name appears once, even when it passes
        the cutoff for both genotypes.
    """
    fraction_het = np.mean(adata_h5.X == 1, axis=0)
    het_germline = adata_h5.var_names[fraction_het > cutoff]

    fraction_hom = np.mean(adata_h5.X == 2, axis=0)
    hom_germline = adata_h5.var_names[fraction_hom > cutoff]

    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys([*het_germline, *hom_germline]))


def call_germline_mutations_from_DNA_seq(donor1_h5, donor2_h5):
    """
    Extract germline variants from the donor B and donor C DNA-seq h5 files.

    Parameters
    ----------
    donor1_h5 : str
        Path to the h5 file processed with the donor "B" amplicon exclusions.
    donor2_h5 : str
        Path to the h5 file processed with the donor "C" amplicon exclusions.

    Returns
    -------
    (list, list)
        Germline variant names for donor B and donor C. Within each list the
        heterozygous calls (genotype 1) come first, then the homozygous
        calls (genotype 2).

    Raises
    ------
    ValueError
        If a file's ``NGT`` layer does not have exactly one row.
    """
    # Amplicons excluded for each donor before variants are called.
    bad_amplicons = {
        "C": frozenset([
            "TAMPL86172", 'TAMPL86171', 'TAMPL86243', 'TAMPL86176', 'TAMPL86144', 'TAMPL86221', 'TAMPL86163', 'TAMPL86186', 'TAMPL86231',
            'TAMPL86189', 'TAMPL86213', 'TAMPL86149', 'TAMPL86140', 'TAMPL86164', 'TAMPL86228', 'TAMPL86150', 'TAMPL86179', 'TAMPL86165',
            'TAMPL86240', 'TAMPL86181', 'TAMPL86183', 'TAMPL86234', 'TAMPL86126', 'TAMPL86222', 'TAMPL86191', 'TAMPL86210', 'TAMPL86167',
            'TAMPL86190', 'TAMPL86128', 'TAMPL86199', 'TAMPL86143', 'TAMPL86141', 'TAMPL86136', 'TAMPL86135', 'TAMPL86214', 'TAMPL86223',
            'TAMPL86220', 'TAMPL86207', 'TAMPL86217', 'TAMPL86129', 'TAMPL86195', 'TAMPL86185', 'TAMPL86224', 'TAMPL86173', 'TAMPL86198',
            'TAMPL86142', 'TAMPL86219',
        ]),
        "B": frozenset([
            "TAMPL86200", "TAMPL86172", 'TAMPL86171', 'TAMPL86243', 'TAMPL86144', 'TAMPL86221', 'TAMPL86163', 'TAMPL86186', 'TAMPL86231',
            'TAMPL86176', 'TAMPL86140', 'TAMPL86149', 'TAMPL86179', 'TAMPL86213', 'TAMPL86189', 'TAMPL86234', 'TAMPL86150', 'TAMPL86165', 'TAMPL86190',
            'TAMPL86240', 'TAMPL86228', 'TAMPL86126', 'TAMPL86167', 'TAMPL86181', 'TAMPL86183', 'TAMPL86210', 'TAMPL86229', 'TAMPL86222', 'TAMPL86191',
            'TAMPL86199', 'TAMPL86128', 'TAMPL86233', 'TAMPL86232', 'TAMPL86135', 'TAMPL86195', 'TAMPL86141', 'TAMPL86143', 'TAMPL86207', 'TAMPL86166',
        ]),
    }

    def call_germline_mutations_for_donor(h5_path, donor="B"):
        # Fail with a clear message instead of an UnboundLocalError later on.
        if donor not in bad_amplicons:
            raise ValueError(f"Unknown donor {donor!r}; expected one of {sorted(bad_amplicons)}")
        excluded = bad_amplicons[donor]

        # Read everything needed while the file is open; the `with` closes it.
        with h5py.File(h5_path, "r") as file:
            dna_variants = file['assays']['dna_variants']
            amplicon_names = [i.decode("utf-8") for i in dna_variants['ca']['amplicon'][:]]
            variant_names = np.array([i.decode("utf-8") for i in dna_variants['ca']['id'][:]])
            mutation_matrix = dna_variants['layers']['NGT'][:, :]

        # Positions of the variants whose amplicon is not excluded.
        good_indices = [index for index, amplicon in enumerate(amplicon_names) if amplicon not in excluded]
        print("Indices of good amplicons:", good_indices)

        # The code below reads a single genotype row, so anything else is an error.
        if mutation_matrix.shape[0] != 1:
            raise ValueError(
                f"Expected exactly one row in the NGT layer of {h5_path}, got {mutation_matrix.shape[0]}"
            )

        # Only keep the good variants.
        variant_names = variant_names[good_indices]
        genotypes = mutation_matrix[0][good_indices]

        het_germline = variant_names[genotypes == 1]
        hom_germline = variant_names[genotypes == 2]
        return list(het_germline) + list(hom_germline)

    donor1_germline_mutations = call_germline_mutations_for_donor(donor1_h5, "B")
    donor2_germline_mutations = call_germline_mutations_for_donor(donor2_h5, "C")

    return donor1_germline_mutations, donor2_germline_mutations






# For AAV control, we may ignore mixed cells and bystander effect
def call_AAV_control_edits(adata_h5, config_path='IRF4.ini'):
    """
    List the AAV control edits present among the variants of ``adata_h5``.

    The ``[AAV_ACBE3]`` section of the config file supplies ``chrom_name``,
    ``start``, ``end`` and ``variant_alleles`` (for example ``C/T``). A
    variant named ``chrom:position:alleles`` is reported when it lies on that
    chromosome, its position is within ``start``..``end`` (both inclusive),
    and its alleles equal ``variant_alleles`` or their base-wise complement
    (``C/T`` -> ``G/A``).

    Only one control region is supported. ``adata_h5`` is not modified.

    Parameters
    ----------
    adata_h5 : anndata.AnnData
        Genotype data; only ``var_names`` is read.
    config_path : str
        Path to the ini file with the mutation configuration.

    Returns
    -------
    list
        Matching variant names, in ``var_names`` order.

    Raises
    ------
    ValueError
        If the first or last character of ``variant_alleles`` is not one of
        A, C, G, T.
    """
    section = 'AAV_ACBE3'
    config = configparser.ConfigParser()
    config.read(config_path)

    chrom_name = config.get(section, 'chrom_name')
    start = config.getint(section, 'start')
    end = config.getint(section, 'end')
    variant_alleles = config.get(section, 'variant_alleles')

    # The same edit called on the opposite strand shows up as the complement.
    complement_rule = {'C': 'G', 'G': 'C', 'A': 'T', 'T': 'A'}
    try:
        complement_alleles = complement_rule[variant_alleles[0]] + "/" + complement_rule[variant_alleles[-1]]
    except (KeyError, IndexError) as err:
        raise ValueError(
            f"The variant alleles is not supported. Current variant alleles is {variant_alleles}"
        ) from err
    accepted_alleles = {variant_alleles, complement_alleles}

    control_editing = []
    for name in adata_h5.var_names:
        fields = name.split(':')
        chrom, loc, edit_type = fields[0], int(fields[1]), fields[2]
        if chrom == chrom_name and start <= loc <= end and edit_type in accepted_alleles:
            control_editing.append(name)

    return control_editing


def call_AAV_PE4_edits(adata_h5, config_path):
    """
    List the PE4max edits present among the variants of ``adata_h5``.

    The ``[AAVS-PE4max]`` section of the config file supplies ``chrom_name``
    and ``locus``. A variant named ``chrom:position:alleles`` is reported
    when it sits exactly at that locus and its alleles are one of
    ``C/T``, ``C/*``, ``G/A`` or ``G/*``.

    Input: Tapestri genotype data and the path of the mutation configuration.
    Output: list of matching variant names, in ``var_names`` order.
    """
    section = 'AAVS-PE4max'
    accepted_alleles = ["C/T", "C/*", "G/A", "G/*"]

    parser = configparser.ConfigParser()
    parser.read(config_path)
    target_chrom = parser.get(section, 'chrom_name')
    target_locus = parser.getint(section, 'locus')

    # Split every variant name once, then pull out the three fields.
    names = np.asarray(adata_h5.var_names.values)
    fields = [name.split(':') for name in names]
    chroms = [parts[0] for parts in fields]
    positions = [int(parts[1]) for parts in fields]
    alleles = [parts[2] for parts in fields]

    return [
        name
        for name, chrom, position, allele in zip(names, chroms, positions, alleles)
        if chrom == target_chrom and position == target_locus and allele in accepted_alleles
    ]




def call_AAV_cells(adata_hypr, adata_h5, AAV_editing):
    """
    Find the cells of ``adata_hypr`` that carry at least one of the given edits.

    A cell counts as edited when its genotype in ``adata_h5`` is 1
    (heterozygous) or 2 (homozygous) for any variant in ``AAV_editing``.
    Nothing is annotated here; the caller decides what to do with the result.

    Parameters
    ----------
    adata_hypr : anndata.AnnData
        RNA data; only ``obs_names`` is read.
    adata_h5 : anndata.AnnData
        Genotype data (cells x variants).
    AAV_editing : sequence of str
        Variant names to look for; each must exist in ``adata_h5.var_names``.

    Returns
    -------
    numpy.ndarray
        Sorted, unique integer positions into ``adata_hypr.obs_names``.
        Empty when ``AAV_editing`` is empty or no cell carries an edit.
    """
    edits = list(AAV_editing)
    if not edits:
        # np.concatenate over zero arrays used to raise here.
        return np.array([], dtype=np.intp)

    columns = [adata_h5.var_names.get_loc(edit) for edit in edits]
    # True for cells that are het (1) or hom (2) at any of the edit columns.
    is_edited = np.isin(np.asarray(adata_h5.X[:, columns]), (1, 2)).any(axis=1)
    edited_barcodes = adata_h5.obs_names[is_edited]

    # Translate the barcodes into row positions of the hypr data.
    _, hypr_idx, _ = np.intersect1d(adata_hypr.obs_names, edited_barcodes, return_indices=True)
    return np.unique(hypr_idx)

def call_AAV4GPPKO_edits(adata_h5, config_path):
    """
    List every variant of ``adata_h5`` that falls inside the AAVS4GPP-KO region.

    The ``[AAVS4GPP-KO]`` section of the config file supplies ``chrom_name``,
    ``start`` and ``end``. A variant named ``chrom:position:alleles`` is
    reported when it lies on that chromosome with ``start <= position <= end``;
    the alleles are not inspected.

    Input: Tapestri genotype data and the path of the mutation configuration.
    Output: list of matching variant names, in ``var_names`` order.
    """
    section = 'AAVS4GPP-KO'

    parser = configparser.ConfigParser()
    parser.read(config_path)
    target_chrom = parser.get(section, 'chrom_name')
    region_start = parser.getint(section, 'start')
    region_end = parser.getint(section, 'end')

    names = np.asarray(adata_h5.var_names.values)
    fields = [name.split(':') for name in names]
    chroms = [parts[0] for parts in fields]
    positions = [int(parts[1]) for parts in fields]
    # Touch the third field so that a name without an allele part raises
    # IndexError even though the alleles play no role in the selection.
    _ = [parts[2] for parts in fields]

    return [
        name
        for name, chrom, position in zip(names, chroms, positions)
        if chrom == target_chrom and region_start <= position < region_end + 1
    ]



def call_TNRC18KO_edits(adata_h5, config_path):
    """
    List the deletion variants of ``adata_h5`` inside the TNRC18-KO region.

    The ``[TNRC18-KO]`` section of the config file supplies ``chrom_name``,
    ``start`` and ``end``. A variant named ``chrom:position:alleles`` is
    reported when it lies on that chromosome with ``start <= position <= end``
    and the last character of its alleles is ``*``.

    Input: Tapestri genotype data and the path of the mutation configuration.
    Output: list of matching variant names, in ``var_names`` order.
    """
    section = 'TNRC18-KO'

    parser = configparser.ConfigParser()
    parser.read(config_path)
    target_chrom = parser.get(section, 'chrom_name')
    region_start = parser.getint(section, 'start')
    region_end = parser.getint(section, 'end')

    names = np.asarray(adata_h5.var_names.values)
    fields = [name.split(':') for name in names]
    chroms = [parts[0] for parts in fields]
    positions = [int(parts[1]) for parts in fields]
    alleles = [parts[2] for parts in fields]

    return [
        name
        for name, chrom, position, allele in zip(names, chroms, positions, alleles)
        if chrom == target_chrom
        and region_start <= position < region_end + 1
        and allele[-1] == "*"
    ]





def get_nearby_variants(
    variant_names, # List of all variant in the loom file
    target_loc, # loci of interest
    germline_amplicons, # List of Germline mutations
    window_size # window size
):
    """
        Get the non-germline single-base variants within a window around target_loc.

        A variant is kept when it is on the same chromosome as the target, its
        position is within ``window_size`` of the target position (inclusive
        on both sides), its alleles have the form ``X/Y`` with single letters
        (so deletions such as ``C/*`` are dropped), and it is not a germline
        variant. The target itself is never treated as germline, so it is
        kept whenever it appears in ``variant_names``.

        ``germline_amplicons`` is not modified.

        Args:
            variant_names (iterable): All variant names (``chrom:pos:alleles``).
            target_loc (str): Target loci (e.g. chr1:123456:A/G).
            germline_amplicons (iterable): Germline variant names.
            window_size (int): Window size.
        Returns:
            list: Nearby variants, in the order of ``variant_names``.
    """
    chrom_t, loc_t, _ = target_loc.split(":")
    lower, upper = int(loc_t) - window_size, int(loc_t) + window_size

    # Work on a private set: the caller's list is reused for other targets and
    # must not lose entries; a set also drops the target however often it occurs.
    germline = set(germline_amplicons) - {target_loc}
    single_base_edit = re.compile(r"^[A-Za-z]/[A-Za-z]$")

    nearby_variants = []
    for variant in variant_names:
        chrom, loc, edit_type = variant.split(":")
        if (
            chrom == chrom_t
            and lower <= int(loc) <= upper
            and single_base_edit.match(edit_type)
            and variant not in germline
        ):
            nearby_variants.append(variant)

    return nearby_variants



def plot_mutant_types(target_loc, new_matrix, mt_type, snp_name, save_path, index=""):
    """
    Save two PDFs describing the genotypes around a target locus.

    1. ``{snp_name}_{mt_type}_mutant_counts_{index}.pdf``: bar chart with the
       number of cells that have a non-zero genotype for each variant; the
       bar of the target locus is coloured red.
    2. ``{snp_name}_{mt_type}_mutant_matrix_{index}.pdf``: row-clustered
       heatmap of the genotype matrix.

    Args:
        target_loc (str): Target locus (e.g. chr1:123456:A/G).
        new_matrix (pd.DataFrame): Cells x variants genotype matrix.
        mt_type (str): Label of the cell group, used in titles and file names.
        snp_name (str): Name of the locus, used in titles and file names.
        save_path (str): Directory the PDFs are written to.
        index (str): Extra tag appended to titles and file names.
    """
    mutation_counts = new_matrix.apply(np.count_nonzero, axis=0)
    # Only the titles use this form; "/" is replaced to keep them path-safe.
    target_label = target_loc.replace("/", "-")

    # --- Bar chart of mutated-cell counts per variant -----------------------
    counts_fig = plt.figure(figsize=(10, 6))
    try:
        bars = plt.bar(mutation_counts.index, mutation_counts.values)
        for bar, count in zip(bars, mutation_counts.values):
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height(),
                str(count),
                ha="center",
                va="bottom",
            )
        if target_loc in mutation_counts.index:
            bars[mutation_counts.index.tolist().index(target_loc)].set_color("r")
        red_patch = mpatches.Patch(color="red", label="Target Loci")
        plt.legend(handles=[red_patch])

        plt.title(
            f"Mutation Counts: {mt_type} - {target_label} ({index}) with total cells {len(new_matrix)}"
        )
        plt.xticks(rotation=90)
        plt.xlabel("")
        plt.ylabel("Number of Mutated Cells")

        save_index = f"{snp_name}_{mt_type}_mutant_counts_{index}.pdf"
        plt.savefig(os.path.join(save_path, save_index), bbox_inches="tight")
    finally:
        # Close the figure even if drawing or saving fails.
        plt.close(counts_fig)

    # --- Heatmap of the matrix, clustering rows only -------------------------
    grid = sns.clustermap(
        new_matrix,
        cmap="YlGnBu",
        row_cluster=True,
        col_cluster=False,
        figsize=(10, 10),
    )
    matrix_fig = grid.ax_heatmap.figure
    try:
        # clustermap builds its own multi-axes figure, so the title and labels
        # are set on the figure and heatmap axes explicitly rather than via
        # pyplot's "current axes".
        matrix_fig.suptitle(
            f"Mutation Matrix: {mt_type} - {snp_name} aka {target_label} with shape {new_matrix.shape}"
        )
        grid.ax_heatmap.set_xlabel("Variant")
        grid.ax_heatmap.set_ylabel("Cell Barcode")
        save_index = f"{snp_name}_{mt_type}_mutant_matrix_{index}.pdf"
        matrix_fig.savefig(os.path.join(save_path, save_index), bbox_inches="tight")
    finally:
        plt.close(matrix_fig)




def process_bystander_cells(target_loc, cells, nearby_variants, adata_h5):
    """
    Split cells into those with and without bystander editing.

    A cell has bystander editing when any nearby variant other than the
    target has a genotype other than 0, after genotype value 3 has been
    reset to 0 for those variants. Genotype codes are taken to be
    non-negative, so this matches a non-zero row sum.

    Args:
        target_loc (str): Target loci (e.g. chr1:123456:A/G).
        cells (sequence): Barcodes of the cells to inspect (the cells that
            supposedly carry the target mutation).
        nearby_variants (sequence): Variant names around the target.
        adata_h5 (anndata.AnnData): Genotype data indexed by barcode and
            variant name.
    Returns:
        set: Pure cells (no bystander editing).
        set: Bystander cells.
        pd.DataFrame: Genotype matrix of ``cells`` x ``nearby_variants`` with
            3 replaced by 0 in every column except the target column.
    """
    nearby_adata = adata_h5[cells, nearby_variants].copy()

    values = np.array(nearby_adata.X)
    is_bystander_col = np.asarray(nearby_adata.var_names != target_loc)

    # Change all 3 to 0, but leave the target loci column untouched.
    bystander_values = values[:, is_bystander_col]
    bystander_values[bystander_values == 3] = 0
    values[:, is_bystander_col] = bystander_values

    new_matrix = pd.DataFrame(
        values, index=nearby_adata.obs_names, columns=nearby_adata.var_names
    )

    # Cells with at least one non-zero bystander genotype.
    has_bystander = (bystander_values != 0).any(axis=1)
    bystander_cells = set(new_matrix.index[has_bystander])

    # Everything that's left is a pure cell. With no bystander cells this is
    # simply all of `cells`.
    pure_cells = set(cells) - bystander_cells
    return pure_cells, bystander_cells, new_matrix





def call_single_loci_cells(
    target_loc, # the loci of interest
    adata_h5, # the loom matrix to retrieve info
    window_size, # the window size
    germline_amplicons, # The germline mutations
    snp_name, # the name of the loci
    plot_path, # the path to save the figures
):
    """
    Classify the cells carrying one target locus as pure or bystander-edited.

    Cells whose genotype at ``target_loc`` is 1 (heterozygous) or 2
    (homozygous) are each split into cells with and without editing at
    nearby non-germline variants. Diagnostic plots are written to
    ``plot_path``; a plotting failure is reported and does not stop the
    annotation.

    Returns:
        tuple of four lists of barcodes, in this order:
        het_bystander, hom_bystander, het_pure, hom_pure
    """
    # Find the column of the loci and the cells that are het / hom there.
    idx = adata_h5.var_names.get_loc(target_loc)
    genotype_calls = adata_h5.X[:, idx]
    het_cells = adata_h5.obs_names[np.flatnonzero(genotype_calls == 1)]
    hom_cells = adata_h5.obs_names[np.flatnonzero(genotype_calls == 2)]

    # Find all nearby variants of the given loci.
    nearby_variants = get_nearby_variants(
        adata_h5.var_names, target_loc, germline_amplicons, window_size
    )

    # These two adjacent chr7 targets are never counted as a bystander edit
    # of one another.
    paired_targets = {
        "chr7:5397122:C/T": "chr7:5397121:C/T",
        "chr7:5397121:C/T": "chr7:5397122:C/T",
    }
    partner = paired_targets.get(target_loc)
    if partner is not None and partner in nearby_variants:
        nearby_variants.remove(partner)

    het_pure, het_bystander, het_matrix = process_bystander_cells(
        target_loc, het_cells, nearby_variants, adata_h5
    )
    hom_pure, hom_bystander, hom_matrix = process_bystander_cells(
        target_loc, hom_cells, nearby_variants, adata_h5
    )

    # Plots are best effort: each one is attempted on its own and the reason
    # for a failure is reported instead of being silently discarded.
    for matrix, label in ((het_matrix, "Heterozygous"), (hom_matrix, "Homozygous")):
        try:
            plot_mutant_types(target_loc, matrix, label, snp_name, plot_path, "1-0")
        except Exception as err:
            print(f"Cannot make plots for the {target_loc} ({label}): {err}")

    return list(het_bystander), list(hom_bystander), list(het_pure), list(hom_pure)


def annotate_cells(adata):
    """
    Derive one genotype label per cell from the boolean flag columns in adata.obs.

    The genotype columns are every column whose name contains
    ``het_bystander`` or ``hom_bystander`` plus the fixed columns
    ``Double_Het``, ``Double_Hom``, ``22_het_21_hom``, ``22_hom_21_het``,
    ``22_het``, ``22_hom``, ``21_hom``, ``21_het`` and ``AAV4GPP_KO``. The
    control column is ``AAV_ACBE3``. All of them must exist in ``adata.obs``.

    Rules, per cell:
    - more than one genotype column set            -> "mixed"
    - exactly one genotype column set and control  -> "mixed"
    - control only                                 -> "AAV_ACBE3"
    - exactly one genotype column set, no control  -> that column's name
    - nothing set                                  -> "unedited"

    Parameters:
    - adata: AnnData object containing the obs DataFrame with the flag columns.

    Returns:
    - None. The labels are stored in ``adata.obs['genotype_annotation']``.
    """
    obs = adata.obs

    snp_columns = [col for col in obs.columns if any(sub in col for sub in ["het_bystander", "hom_bystander"])]
    snp_columns += ["Double_Het", "Double_Hom", "22_het_21_hom", "22_hom_21_het", "22_het", "22_hom", "21_hom", "21_het", "AAV4GPP_KO"]
    control_column = "AAV_ACBE3"

    # Work on plain arrays so every rule is evaluated for all cells at once.
    snp_frame = obs[snp_columns]
    snp_flags = snp_frame.to_numpy(dtype=bool)
    column_names = snp_frame.columns.to_numpy(dtype=object)
    is_control = obs[control_column].to_numpy(dtype=bool)

    n_active = snp_flags.sum(axis=1)
    single = n_active == 1
    # Name of the first set column in each row; only used where exactly one is set.
    first_active = column_names[snp_flags.argmax(axis=1)]

    labels = np.full(len(obs), "unedited", dtype=object)
    labels[single] = first_active[single]
    labels[is_control & (n_active == 0)] = "AAV_ACBE3"
    # Applied last so that "mixed" overrides a single-genotype label.
    labels[(n_active > 1) | (is_control & single)] = "mixed"

    adata.obs['genotype_annotation'] = labels


def call_donor_dtype(adata_h5, unique_B, unique_C):
    """
    Assign each cell to donor B or donor C from donor-specific germline variants.

    For each donor, the score of a cell is the number of that donor's unique
    variants present in ``adata_h5.var_names`` for which the cell has
    genotype 1 or 2, divided by the length of the donor's variant list.
    Genotype 3 counts as no mutation. A cell is labelled "Donor C" when its
    C score is strictly greater than its B score, otherwise "Donor B" (ties
    included). An empty variant list gives that donor a score of 0 for every
    cell.

    ``adata_h5`` is not modified.

    Parameters
    ----------
    adata_h5 : anndata.AnnData
        Genotype data (cells x variants).
    unique_B, unique_C : sequence of str
        Variant names unique to donor B / donor C.

    Returns
    -------
    numpy.ndarray
        One string per cell: "Donor B" or "Donor C".
    """
    genotypes = adata_h5.X
    if hasattr(genotypes, "toarray"):
        genotypes = genotypes.toarray()
    genotypes = np.asarray(genotypes)

    # Het (1) and hom (2) both count as "has the mutation"; 0 and 3 do not.
    has_mutation = (genotypes == 1) | (genotypes == 2)

    def _mutated_fraction(unique_variants):
        # Fraction of a donor's unique variants that are mutated in each cell.
        if len(unique_variants) == 0:
            return np.zeros(has_mutation.shape[0])
        columns = np.asarray(adata_h5.var_names.isin(unique_variants))
        return has_mutation[:, columns].sum(axis=1) / len(unique_variants)

    B_prob = _mutated_fraction(unique_B)
    C_prob = _mutated_fraction(unique_C)

    return np.where(B_prob < C_prob, "Donor C", "Donor B")



def annotate_genotype(adata_h5,
                    adata_hypr,
                    config_path,
                    save_path="./",
                    plot_path="./"):
    """
    Annotate adata_hypr.obs with the genotype information found in adata_h5.

    Columns added to ``adata_hypr.obs`` (all boolean unless noted):
    - ``AAV_ACBE3``: cell carries an AAV control edit
    - ``AAV4GPP_KO``: cell carries a variant in the AAVS4GPP-KO region
    - for every other config section ``<snp>``:
      ``<snp>_het_bystander``, ``<snp>_het_pure``, ``<snp>_hom_bystander``,
      ``<snp>_hom_pure``. They are created for every section and stay False
      when neither the configured variant nor its complement is present in
      ``adata_h5``.
    - ``Double_Het``, ``Double_Hom``, ``22_het_21_hom``, ``22_hom_21_het``,
      ``22_het``, ``22_hom``, ``21_het``, ``21_hom``: combinations of the
      pure calls of ``TNRC18-rs748670681`` ("22") and ``TNRC18-21CT`` ("21").
      They are all False when those sections are not configured.
    - ``genotype_annotation`` (string): final per-cell label written by
      ``annotate_cells`` (single genotype, "mixed", "AAV_ACBE3" or "unedited").

    Parameters
    ----------
    adata_h5 : anndata.AnnData
        Genotype data (cells x variants), same barcodes as ``adata_hypr``.
    adata_hypr : anndata.AnnData
        RNA data; its ``obs`` is modified in place.
    config_path : str
        Path to the ini file describing the edits of interest.
    save_path : str
        Not used by this function.
    plot_path : str
        Directory for the per-SNP diagnostic plots.

    Returns
    -------
    (anndata.AnnData, anndata.AnnData)
        ``adata_hypr`` (annotated) and ``adata_h5``.

    Raises
    ------
    ValueError
        If a SNP's ``variant_alleles`` does not start and end with one of
        A, C, G, T.
    """

    def _flag_edited_cells(column, edits):
        # Create `column` as all-False, then mark the cells carrying any edit.
        adata_hypr.obs[column] = False
        if len(edits) == 0:
            # Nothing to look up; the column simply stays False.
            return
        cell_idx = call_AAV_cells(adata_hypr, adata_h5, AAV_editing=edits)
        adata_hypr.obs.iloc[cell_idx, adata_hypr.obs.columns.get_loc(column)] = True

    def _pure_flag(snp, zygosity):
        # Pure-call column of a SNP, or all-False when the SNP is not configured.
        column = f"{snp}_{zygosity}_pure"
        if column in adata_hypr.obs.columns:
            return adata_hypr.obs[column].astype(bool)
        return pd.Series(False, index=adata_hypr.obs.index)

    # Step 1, determine the germline mutations
    germline_amplicons = call_germline_mutations(adata_h5, cutoff=0.25)

    # Step 2, annotate the AAV control cells
    AAV_control_editing = call_AAV_control_edits(adata_h5, config_path=config_path)
    _flag_edited_cells("AAV_ACBE3", AAV_control_editing)

    # Step 3, annotate the AAVS4GPP-KO cells
    AAV4GPPKO = call_AAV4GPPKO_edits(adata_h5, config_path=config_path)
    _flag_edited_cells("AAV4GPP_KO", AAV4GPPKO)

    # Step 4/5, every remaining config section describes one SNP of interest
    config = configparser.ConfigParser()
    config.read(config_path)
    exclude_sections = ["AAV_ACBE3", "AAVS4GPP-KO", "TNRC18-KO"]
    SNPs = [section for section in config.sections() if section not in exclude_sections]

    annotation = ["het_bystander", "het_pure", "hom_bystander", "hom_pure"]
    complement_rule = {'C': 'G', 'G': 'C', 'A': 'T', 'T': 'A'}

    for snp in SNPs:
        chrom_name = config.get(snp, 'chrom_name')
        locus = config.getint(snp, 'locus')
        alleles = config.get(snp, 'variant_alleles')
        try:
            complement_variant_alleles = complement_rule[alleles[0]] + "/" + complement_rule[alleles[-1]]
        except (KeyError, IndexError) as err:
            raise ValueError(
                f"The variant alleles is not supported. Current variant alleles is {alleles}"
            ) from err

        # Create the columns up front so they exist even when the variant is
        # absent from the loom file and the SNP is skipped below.
        for label in annotation:
            adata_hypr.obs[f"{snp}_{label}"] = False

        target_loc = ":".join([chrom_name, str(locus), alleles])
        complement_target_loc = ":".join([chrom_name, str(locus), complement_variant_alleles])

        # Use whichever form is present: the configured alleles or the complement.
        if target_loc not in adata_h5.var_names:
            if complement_target_loc not in adata_h5.var_names:
                print(f"Both the mutation {target_loc} and the reversed one {complement_target_loc} is not found in the loom file")
                continue
            target_loc = complement_target_loc

        # Barcodes of the cells in each category
        het_bystander, hom_bystander, het_pure, hom_pure = call_single_loci_cells(
            target_loc, adata_h5, window_size=10,
            germline_amplicons=germline_amplicons, snp_name=snp, plot_path=plot_path)

        for label, barcodes in zip(annotation, (het_bystander, het_pure, hom_bystander, hom_pure)):
            adata_hypr.obs.loc[barcodes, f"{snp}_{label}"] = True

    # Step 6.5, cross annotate the two TNRC18 loci ("22" and "21")
    het_22 = _pure_flag('TNRC18-rs748670681', 'het')
    hom_22 = _pure_flag('TNRC18-rs748670681', 'hom')
    het_21 = _pure_flag('TNRC18-21CT', 'het')
    hom_21 = _pure_flag('TNRC18-21CT', 'hom')
    no_22 = (~het_22) & (~hom_22)
    no_21 = (~het_21) & (~hom_21)

    cross_annotations = {
        "Double_Het": het_22 & het_21,
        "Double_Hom": hom_22 & hom_21,
        "22_het_21_hom": het_22 & hom_21,
        "22_hom_21_het": hom_22 & het_21,
        "22_het": het_22 & no_21,
        "22_hom": hom_22 & no_21,
        "21_het": het_21 & no_22,
        "21_hom": hom_21 & no_22,
    }
    for column, flags in cross_annotations.items():
        adata_hypr.obs[column] = flags

    # Step 7, collapse the flags into one label per cell ("mixed" when several apply)
    annotate_cells(adata_hypr)
    return adata_hypr, adata_h5
    



def main():
    """Command-line entry point: filter, intersect and annotate the data.

    Reads the Tapestri genotype file (``--tapestri_path``) and the Hypr-seq
    phenotype file (``--hypr_path``), filters the phenotype matrix with
    ``filter_hypr``, passes both objects through
    ``find_intersecting_and_filter`` and ``annotate_genotype``, and finally
    writes the two annotated objects as ``.h5ad`` files under ``--save_path``.

    ``--probe_level 1`` keeps the probe-level matrix and writes
    ``Annotated_phenotype_hypr_seq_probe.h5ad``; any other integer aggregates
    to gene level and writes ``Annotated_phenotype_hypr_seq.h5ad``.
    """
    # Create the parser
    parser = argparse.ArgumentParser(description="Filter, Intersect and Annotate \
     the genotype information to the phenotype RNA-seq count matrix")

    # Add arguments for file paths
    parser.add_argument('--tapestri_path', type=str, required=True, help='The path to the h5 file -- genotype')
    parser.add_argument('--hypr_path', type=str, required=True, help='The path to the hypr file -- phenotype')
    parser.add_argument('--config_path', type=str, required=True, help='The path to the variant configuration file')
    parser.add_argument('--save_path', type=str, required=True, help='The path to save annotated anndata')
    parser.add_argument('--plot_path', type=str, required=True, help='The path to save genotype clustermap figure')
    parser.add_argument('--probe_level', type=int, required=True, help="1 for probe level matrix, other numbers for gene_level")

    # Parse the arguments
    args = parser.parse_args()

    # Read the genotype and phenotype data
    adata_h5 = read_tapestri_H5(args.tapestri_path)
    adata_hypr = sc.read(args.hypr_path)

    # Only the exact value 1 selects the probe-level matrix
    probe_level = args.probe_level == 1

    adata_hypr = filter_hypr(adata_hypr, probe_level=probe_level)

    config_path = args.config_path
    save_path = args.save_path
    plot_path = args.plot_path

    # Create the output directories (including intermediate ones).
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(plot_path, exist_ok=True)

    # Filter and intersect between hypr matrix and loom matrix
    adata_hypr, adata_h5 = find_intersecting_and_filter(adata_hypr, adata_h5)
    # Annotate the filtered data
    adata_hypr, adata_h5 = annotate_genotype(adata_h5=adata_h5, adata_hypr=adata_hypr, config_path=config_path, save_path=save_path, plot_path=plot_path)

    if probe_level:
        hypr_name = f"{save_path}/Annotated_phenotype_hypr_seq_probe.h5ad"
    else:
        hypr_name = f"{save_path}/Annotated_phenotype_hypr_seq.h5ad"
    sc.write(hypr_name, adata_hypr)
    sc.write(f"{save_path}/genotype_loom.h5ad", adata_h5)



if __name__ == "__main__":
    # Run the command-line pipeline when executed as a script.
    # Example inputs noted by the author:
    #   - the config is TNRC18.ini
    #   - the Tapestri h5 file is Th1_TNRC18.dna.h5
    #   - the hypr file is Th1_TNRC18.h5ad
    main() 

In [None]:
# compute DEGs

def compute_deg(input_path, save_path):
    """Load an annotated AnnData file and export pairwise genotype DEGs.

    The data read from ``input_path`` are normalised to a total of 1e4 counts
    per cell and log1p-transformed, the ``genotype_annotation`` column is cast
    to a categorical, and the result is handed to
    ``compute_pairwise_degs_all`` together with ``save_path``.
    """
    annotation_key = "genotype_annotation"

    adata = sc.read(input_path)

    # Normalise per-cell totals, then log-transform in place.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # The grouping column is stored as a categorical for the ranking step.
    genotype_labels = adata.obs[annotation_key]
    adata.obs[annotation_key] = genotype_labels.astype('category')

    compute_pairwise_degs_all(adata, save_path)

def compute_marker_genes_and_plot(adata, condition_1, condition_2, task_name=""):
    """Rank genes between two genotype groups and return the statistics.

    Cells whose ``obs["genotype_annotation"]`` equals ``condition_1`` or
    ``condition_2`` are selected, ``sc.tl.rank_genes_groups`` is run on that
    subset with the Wilcoxon method, and the values reported for
    ``condition_1`` are collected into a table.

    ``condition_2`` should be the control group; otherwise the saved values
    may not be reliable. Despite the function name, no plot is produced: the
    plotting calls are disabled.

    Args:
        adata: AnnData object with a ``genotype_annotation`` column in ``obs``.
        condition_1: Genotype label whose statistics are reported.
        condition_2: Genotype label used as the comparison (control) group.
        task_name: Prefix used in the names of the returned columns.

    Returns:
        A ``pandas.DataFrame`` indexed by ``gene_name`` (sorted by name) with
        the columns ``{task_name}_{condition_1}_VS_{condition_2}_log_fc`` and
        ``{task_name}_{condition_1}_VS_{condition_2}_p_val`` (adjusted
        p-values).
    """
    genotype = adata.obs["genotype_annotation"]
    selected = (genotype == condition_1) | (genotype == condition_2)

    # Take an explicit copy: the subset is modified below (obs column and
    # .uns), and writing to an AnnData view would otherwise trigger an
    # implicit copy together with a warning.
    filtered_adata = adata[selected].copy()

    filtered_adata.obs['genotype_annotation'] = filtered_adata.obs['genotype_annotation'].astype('category')
    sc.tl.rank_genes_groups(filtered_adata, "genotype_annotation", method="wilcoxon")

    ranked = filtered_adata.uns["rank_genes_groups"]
    column_prefix = f'{task_name}_{condition_1}_VS_{condition_2}'

    df = pd.DataFrame({
        'gene_name': ranked["names"][condition_1],
        f'{column_prefix}_log_fc': ranked["logfoldchanges"][condition_1],
        f'{column_prefix}_p_val': ranked["pvals_adj"][condition_1],
    })
    # Sort by gene so tables from different comparisons can be aligned.
    df = df.sort_values("gene_name")
    df = df.set_index("gene_name")

    return df




def compute_pairwise_degs_all(adata, save_path):
    """Compare every eligible genotype against the control and save a CSV.

    For each value of ``obs["genotype_annotation"]`` that is not in the
    exclusion list and has at least five cells,
    ``compute_marker_genes_and_plot`` is called with that genotype as
    ``condition_1`` and ``"AAV_ACBE3"`` as ``condition_2``. The resulting
    tables are concatenated column-wise and written to
    ``{save_path}/genotype_pairwise_probe_deg.csv``.

    Args:
        adata: AnnData object with a ``genotype_annotation`` column in ``obs``.
        save_path: Output directory; created if it does not exist.

    Raises:
        ValueError: If no genotype passes the filters, so there is nothing
            to write.
    """
    control_genotype = "AAV_ACBE3"
    exclude_genotype = ["AAV_ACBE3", "mixed", "unedited", "unknown"]
    min_cells = 5

    genotype = adata.obs["genotype_annotation"]
    all_genotypes = [i for i in np.unique(genotype) if i not in exclude_genotype]

    all_dfs_subset = []
    for cond1 in all_genotypes:
        # Skip genotypes with too few cells for a meaningful comparison.
        if (genotype == cond1).sum() < min_cells:
            continue
        result_df = compute_marker_genes_and_plot(
            adata, condition_1=cond1, condition_2=control_genotype, task_name="all"
        )
        all_dfs_subset.append(result_df)

    if not all_dfs_subset:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(
            f"No genotype with at least {min_cells} cells to compare against "
            f"{control_genotype!r}; nothing to write."
        )

    # Create the directory, including any intermediate directories
    os.makedirs(save_path, exist_ok=True)
    final_subset = pd.concat(all_dfs_subset, axis=1)
    final_subset.to_csv(f"{save_path}/genotype_pairwise_probe_deg.csv")

    
# Run the DEG computation on the probe-level annotated phenotype matrix and
# store the resulting table in the probe_deg analysis directory.
compute_deg(
    input_path="/mnt/data/project/25_02_15_stag_analysis/merge_annotate_result/TNRC18_Th1/Annotated_phenotype_hypr_seq_probe.h5ad",
    save_path="/mnt/data/project/25_02_15_stag_analysis/analysis_result/TNRC18_Th1/probe_deg",
)
