In [None]:
# This script takes the Hypr-seq AnnData and the Tapestri loom file as input
# and outputs an scRNA-seq AnnData with detailed genotype annotations.

import scanpy as sc
import numpy as np
import pandas as pd

import argparse
import seaborn as sns
import sys
import loompy
import logging
import configparser
import anndata
import matplotlib.pyplot as plt
import re
import matplotlib.patches as mpatches
import os
import argparse
import copy
import h5py

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

# Read and filter the Hypr-seq anndata
def filter_hypr(adata, probe_level=False):
    """
    Filter low-count cells and aggregate the probe matrix by gene name.

    Cells with fewer than 1000 total counts are removed in place through
    ``sc.pp.filter_cells``. Unless ``probe_level`` is true, each variable name
    is cut at its first underscore and all columns sharing the resulting name
    are summed, turning the cell-by-probe matrix into a cell-by-gene matrix.

    Parameters
    ----------
    adata : anndata.AnnData
        Cell-by-probe count matrix. It is filtered in place and gains a
        ``gene_name`` column in ``adata.var``.
    probe_level : bool
        When true, every variable keeps its own name, so no columns are merged.

    Returns
    -------
    anndata.AnnData
        A new object holding the aggregated matrix and a copy of the filtered
        ``adata.obs``. ``var``, ``uns`` and layers are not carried over.
    """
    # Remove cells with fewer than 1000 counts per cell
    sc.pp.filter_cells(adata, min_counts=1000)

    # Splitting at the underscore separates gene names from probe numbers
    if probe_level:
        gene_names = list(adata.var_names)
    else:
        gene_names = [name.split('_')[0] for name in adata.var_names]
    adata.var['gene_name'] = gene_names

    # pd.DataFrame does not interpret a scipy sparse matrix as a 2-D numeric
    # table, so densify X first; a dense X passes through unchanged.
    counts = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    adata_df = pd.DataFrame(counts.T, index=adata.var_names, columns=adata.obs_names)

    # Sum the rows (probes) that share a gene name, then transpose back to
    # cells as rows and genes as columns.
    aggregated_data = adata_df.groupby(adata.var['gene_name']).sum().T

    adata_aggregated = sc.AnnData(X=aggregated_data)
    adata_aggregated.obs = adata.obs.copy()

    return adata_aggregated


# function reads the loomfile downloaded from Tapestri portal
def read_tapestri_loom(filename):
    """
    Read data from MissionBio's formatted loom file.

    The file is opened inside a ``with`` block so the connection is closed
    even when one of the reads fails.

    Parameters
    ----------
    filename : str
        Path to the loom file (.loom)

    Returns
    -------
    anndata.AnnData
        An anndata object (cells x variants) with the following contents:
        adata.X: GATK calls (the transposed main loom matrix)
        adata.layers['e']: reads with evidence of mutation (loom layer 'AD')
        adata.layers['no_e']: reads without evidence of mutation (loom layer 'RO')
        adata.var_names: loom row attribute 'id'
        adata.obs_names: loom column attribute 'barcode', cut at the first '-'
        adata.varm['amplicon'], ['chrom'], ['loc']: loom row attributes
        'amplicon', 'CHROM' and 'POS'
    """
    with loompy.connect(filename) as loom_file:
        variant_names = loom_file.ra['id']
        amplicon_names = loom_file.ra['amplicon']
        chromosome = loom_file.ra['CHROM']
        location = loom_file.ra['POS']

        barcodes = [barcode.split('-')[0] for barcode in loom_file.ca['barcode']]

        # Load the main matrix once; it is reused for both the data and the dtype.
        calls = np.transpose(loom_file[:, :])
        adata = anndata.AnnData(calls, dtype=calls.dtype)
        adata.layers['e'] = np.transpose(loom_file.layers['AD'][:, :])
        adata.layers['no_e'] = np.transpose(loom_file.layers['RO'][:, :])

    adata.var_names = variant_names
    adata.obs_names = barcodes
    adata.varm['amplicon'] = amplicon_names
    adata.varm['chrom'] = chromosome
    adata.varm['loc'] = location

    return adata

# Optional function, visualize the barcode distribution
def barcode_rank_plot(adata, minimum=0, xmax=None):
    """
    Show a log-log barcode rank plot of the per-cell total counts.

    Parameters
    ----------
    adata : anndata.AnnData
        Object whose ``X`` is summed along axis 1 (one total per cell).
    minimum : number
        Only cells whose total is strictly greater than this value are drawn.
    xmax : number or None
        Optional right limit of the x axis.
    """
    # X.sum(axis=1) is 1-D for a dense array but an (n, 1) np.matrix for a
    # scipy sparse matrix; flatten both to a plain 1-D array before masking.
    cell_umi_counts_all = np.asarray(adata.X.sum(axis=1)).ravel()
    cell_umi_counts = cell_umi_counts_all[cell_umi_counts_all > minimum]

    # Sort the counts in descending order for the rank plot
    sorted_umi_counts = np.sort(cell_umi_counts)[::-1]

    plt.figure(figsize=(5, 4), dpi=150)
    sns.lineplot(x=range(1, len(sorted_umi_counts) + 1), y=sorted_umi_counts)
    plt.xlabel('Barcode Rank')
    plt.ylabel('UMI Count')
    plt.title('Cell Barcode Rank Plot')
    plt.yscale('log')  # Log scale for better visualization
    plt.xscale('log')  # Log scale
    if xmax is not None:
        # A log axis cannot start at 0, so only the right edge is constrained.
        plt.xlim(right=xmax)
    plt.show()


# Function for finding the intersecting barcodes between the modalities
# We can drop the idea of mudata for simplicity
def find_intersecting_and_filter(adata_hypr, adata_loom):
    """
    Restrict both datasets to their shared barcodes, in the same order.

    The shared barcodes come from ``np.intersect1d`` on the two ``obs_names``;
    the returned positions are used to subset each object, so row i of the
    first result and row i of the second result refer to the same barcode.

    Returns
    -------
    tuple
        ``(adata_hypr_subset, adata_loom_subset)``, each obtained by indexing
        the corresponding input with the positions of the shared barcodes.
    """
    hypr_barcodes = adata_hypr.obs_names
    loom_barcodes = adata_loom.obs_names

    # intersect1d returns (common values, positions in first, positions in second)
    shared = np.intersect1d(hypr_barcodes, loom_barcodes, return_indices=True)
    hypr_rows, loom_rows = shared[1], shared[2]

    # Report how many barcodes each modality has and how many overlap
    logger.info(f"Found {len(hypr_barcodes)} barcodes in modality hypr seq")
    logger.info(f"Found {len(loom_barcodes)} barcodes in modality loom file")
    logger.info(f"Found {len(hypr_rows)} intersecting barcodes")

    return adata_hypr[hypr_rows, :], adata_loom[loom_rows, :]
    

# Step 1: determine the germline mutations from the loom file.
# A mutation that appears in too many cells of the dataset is treated as germline
# (e.g., if more than 25% of cells carry the homozygous call, it is called homozygous germline).
def call_germline_mutations(adata_loom, cutoff=0.25):
    """
    List the variants whose call frequency exceeds ``cutoff``.

    For genotype code 1 and then for genotype code 2, the fraction of cells
    carrying that code is computed per variant; every variant whose fraction
    is strictly greater than ``cutoff`` is reported. Code-1 hits come first,
    then code-2 hits, so a variant passing both tests appears twice.

    Returns
    -------
    list
        Variant names taken from ``adata_loom.var_names``.
    """
    germline = []
    for genotype_code in (1, 2):
        fraction = np.mean(adata_loom.X == genotype_code, axis=0)
        germline.extend(adata_loom.var_names[fraction > cutoff])
    return germline


# For AAV control, we may ignore mixed cells and bystander effect
def call_AAV_control_edits(adata_loom, config_path='AAV_config.ini'):
    """
    Find the AAV control edits that are present in the loom variants.

    The ``[AAV_control]`` section of the config file supplies ``chrom_name``,
    ``start``, ``end`` and ``variant_alleles`` (e.g. ``A/G``). A variant named
    ``chrom:pos:edit`` is reported when its chromosome matches, its position
    lies in ``[start, end]`` (both inclusive) and its edit equals either
    ``variant_alleles`` or the base-complemented form of it.

    Only one control region (a single continuous interval) is supported.
    This function does not modify ``adata_loom``.

    Args:
        adata_loom: object whose ``var_names`` follow the ``chrom:pos:edit`` format.
        config_path (str): path to the INI configuration file.
    Returns:
        list: names of the matching variants, in ``var_names`` order.
    Raises:
        ValueError: if the first or last character of ``variant_alleles`` is
            not one of A, C, G, T (or the value is empty).
    """
    # Read configuration settings
    config = configparser.ConfigParser()
    config.read(config_path)

    chrom_name = config.get('AAV_control', 'chrom_name')
    start = config.getint('AAV_control', 'start')
    end = config.getint('AAV_control', 'end')
    variant_alleles = config.get('AAV_control', 'variant_alleles')

    # The same edit may be reported on the opposite strand, so also accept
    # the complemented alleles.
    complement_rule = {'C': 'G', 'G': 'C', 'A': 'T', 'T': 'A'}
    try:
        reverse_variant_alleles = (
            complement_rule[variant_alleles[0]] + "/" + complement_rule[variant_alleles[-1]]
        )
    except (KeyError, IndexError) as err:
        raise ValueError(f"The variant alleles is not supported. Current variant alleles is {variant_alleles}") from err
    accepted_edits = (variant_alleles, reverse_variant_alleles)

    control_editing = []
    for name in np.asarray(adata_loom.var_names.values):
        fields = name.split(':')
        chrom, loc, edit_type = fields[0], int(fields[1]), fields[2]
        if chrom == chrom_name and start <= loc <= end and edit_type in accepted_edits:
            control_editing.append(name)

    return control_editing


def call_AAV_del_edits(adata_loom, config_path):
    """
    Find the AAV deletion edits that are present in the loom variants.

    The ``[AAV_del]`` section of the config file supplies ``chrom_name``,
    ``start`` and ``end``. A variant named ``chrom:pos:edit`` is reported when
    its chromosome matches, its position lies in ``[start, end]`` (both
    inclusive) and its edit string ends with ``*``.

    Only one deletion region (a single continuous interval) is supported.
    This function does not modify ``adata_loom``.

    Args:
        adata_loom: object whose ``var_names`` follow the ``chrom:pos:edit`` format.
        config_path (str): path to the INI configuration file.
    Returns:
        list: names of the matching variants, in ``var_names`` order.
    """
    # Read configuration settings
    config = configparser.ConfigParser()
    config.read(config_path)

    chrom_name = config.get('AAV_del', 'chrom_name')
    start = config.getint('AAV_del', 'start')
    end = config.getint('AAV_del', 'end')

    del_editing = []
    for name in np.asarray(adata_loom.var_names.values):
        fields = name.split(':')
        chrom, loc, edit_type = fields[0], int(fields[1]), fields[2]
        # endswith() is safe on an empty edit field, unlike indexing with [-1]
        if chrom == chrom_name and start <= loc <= end and edit_type.endswith("*"):
            del_editing.append(name)

    return del_editing


def call_AAV_cells(adata_hypr, adata_loom, AAV_editing):
    """
    Locate the cells of ``adata_hypr`` that carry at least one of the given edits.

    A cell carries an edit when its value in ``adata_loom.X`` for that variant
    is 1 or 2 (heterozygous / homozygous). The function only computes
    positions; the caller performs the actual annotation so that it happens
    in one place.

    Args:
        adata_hypr: object providing ``obs_names`` (the phenotype cells).
        adata_loom: object providing ``X``, ``obs_names`` and ``var_names``.
        AAV_editing (list): variant names to look up in ``adata_loom.var_names``.
    Returns:
        np.ndarray: sorted, unique integer positions into ``adata_hypr.obs_names``;
        empty when ``AAV_editing`` is empty.
    """
    all_idx = []
    for edit in AAV_editing:
        # column of the edit in the loom matrix
        idx = adata_loom.var_names.get_loc(edit)
        genotype = adata_loom.X[:, idx]
        edited_rows = np.flatnonzero((genotype == 1) | (genotype == 2))
        AAV_cells_barcode_loom = adata_loom.obs_names[edited_rows]
        # translate the loom barcodes into positions within adata_hypr
        _, idx1, _ = np.intersect1d(adata_hypr.obs_names, AAV_cells_barcode_loom, return_indices=True)
        all_idx.append(idx1)

    # np.concatenate rejects an empty list, which happens when no edit was found
    if not all_idx:
        return np.array([], dtype=np.intp)
    return np.unique(np.concatenate(all_idx))



def get_nearby_variants(
    variant_names, # List of all variant in the loom file
    target_loc, # loci of interest
    germline_amplicons, # List of Germline mutations 
    window_size # window size
):
    """
        Get nearby variants within a window size of the target_loc.

        A variant is kept when it is on the same chromosome as the target, its
        position is within ``window_size`` of the target position (inclusive on
        both sides), its edit is a single-letter substitution such as ``A/G``
        and it is not listed in ``germline_amplicons``. The target itself is
        kept when it satisfies these rules.

        Args:
            variant_names (list): List of variant names (``chrom:pos:edit``).
            target_loc (str): Target loci (e.g. chr1:123456:A/G).
            germline_amplicons (list): List of germline amplicons.
            window_size (int): Window size.
        Returns:
            list: List of nearby variants, in the order of ``variant_names``.
    """
    chrom_t, loc_t, _ = target_loc.split(":")
    lower = int(loc_t) - window_size
    upper = int(loc_t) + window_size
    # A set gives constant-time membership tests inside the loop.
    germline = set(germline_amplicons)
    # Single-letter substitution; this also rejects deletions ending in "*".
    substitution = re.compile(r"^[A-Za-z]/[A-Za-z]$")

    nearby_variants = []
    for variant in variant_names:
        chrom, loc, edit_type = variant.split(":")
        if chrom != chrom_t or not lower <= int(loc) <= upper:
            continue
        if substitution.match(edit_type) and variant not in germline:
            nearby_variants.append(variant)

    return nearby_variants



def plot_mutant_types(target_loc, new_matrix, mt_type, snp_name, save_path, index=""):
    """
    Save a bar chart of mutated-cell counts and a row-clustered genotype heat map.

    Two PDF files are written into ``save_path``:
    ``{snp_name}_{mt_type}_mutant_counts_{index}.pdf`` and
    ``{snp_name}_{mt_type}_mutant_matrix_{index}.pdf``. The second file is
    skipped when the matrix has fewer than two cells or no variants, because
    row clustering needs at least two observations.

    Args:
        target_loc (str): target variant; its bar is coloured red when present.
        new_matrix (pd.DataFrame): cells x variants genotype matrix.
        mt_type (str): label used in titles and file names (e.g. "Heterozygous").
        snp_name (str): SNP label used in titles and file names.
        save_path (str): existing directory that receives the figures.
        index (str): extra tag appended to titles and file names.
    """
    mutation_counts = new_matrix.apply(np.count_nonzero, axis=0)
    plt.figure(figsize=(10, 6))
    bars = plt.bar(mutation_counts.index, mutation_counts.values)
    # Write each count on top of its bar
    for bar, count in zip(bars, mutation_counts.values):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            str(count),
            ha="center",
            va="bottom",
        )
    if target_loc in mutation_counts.index:
        bars[mutation_counts.index.tolist().index(target_loc)].set_color("r")
    red_patch = mpatches.Patch(color="red", label="Target Loci")
    plt.legend(handles=[red_patch])

    # "/" in the locus name would be read as a path separator, so replace it
    target_loc = target_loc.replace("/", "-")

    plot_title = (
        f"Mutation Counts: {mt_type} - {target_loc} ({index}) with total cells {len(new_matrix)}"
    )
    plt.title(plot_title)
    plt.xticks(rotation=90)
    plt.xlabel("")
    plt.ylabel("Number of Mutated Cells")

    save_index = f"{snp_name}_{mt_type}_mutant_counts_{index}.pdf"
    plt.savefig(os.path.join(save_path, save_index), bbox_inches="tight")
    plt.close()

    # Hierarchical clustering of rows cannot run on fewer than two cells, so
    # skip the heat map instead of aborting the whole annotation run.
    n_cells, n_variants = new_matrix.shape
    if n_cells < 2 or n_variants == 0:
        logging.getLogger(__name__).warning(
            "Skipping %s clustermap for %s: matrix shape %s is too small to cluster",
            mt_type, snp_name, new_matrix.shape,
        )
        return

    # Plot a cluster map of matrix but cluster rows only.
    sns.clustermap(
        new_matrix,
        cmap="YlGnBu",
        row_cluster=True,
        col_cluster=False,
        figsize=(10, 10),
    )
    plt.title(
        f"Mutation Matrix: {mt_type} - {snp_name} aka {target_loc} with shape {new_matrix.shape}"
    )
    plt.xlabel("Variant")
    plt.ylabel("Cell Barcode")
    save_index = f"{snp_name}_{mt_type}_mutant_matrix_{index}.pdf"
    plt.savefig(os.path.join(save_path, save_index), bbox_inches="tight")
    plt.close()




def process_bystander_cells(target_loc, cells, nearby_variants, adata_loom):
    """
    Identify which cells have bystander editing.

    The genotype sub-matrix ``adata_loom[cells, nearby_variants]`` is turned
    into a DataFrame. In every column except the target, the value 3 is
    replaced by 0 (NOTE(review): 3 is presumably the "no call" code - confirm).
    A cell is a bystander cell when the sum of its non-target columns is
    non-zero; all remaining cells are pure.

    Args:
        target_loc (str): Target loci (e.g. chr1:123456:A/G).
        cells (list): Barcodes of the cells that supposedly have the mutation.
        nearby_variants (list): Variant names near the target locus.
        adata_loom: Genotype object indexable as ``adata_loom[cells, variants]``.
    Returns:
        set: Barcodes of pure cells (no bystander editing).
        set: Barcodes of bystander cells.
        pd.DataFrame: The cells x nearby-variants matrix after the 3 -> 0
            replacement (the target column, if present, is left untouched).
    """
    nearby_adata = adata_loom[cells, nearby_variants].copy()
    new_matrix = pd.DataFrame(
        nearby_adata.X, index=nearby_adata.obs_names, columns=nearby_adata.var_names
    )

    # Boolean flag per column: True for every variant except the target locus.
    is_bystander_column = new_matrix.columns != target_loc

    # Zero out the 3s, but only in the bystander columns.
    needs_reset = (new_matrix == 3).to_numpy() & is_bystander_column
    new_matrix = new_matrix.mask(needs_reset, 0)

    # A non-zero row sum over the bystander columns marks a bystander cell.
    bystander_load = new_matrix.loc[:, is_bystander_column].sum(axis=1)
    bystander_cells = set(bystander_load.index[bystander_load != 0])

    pure_cells = set(cells) - bystander_cells
    return pure_cells, bystander_cells, new_matrix



# Group of neighbouring chr5 target loci that must not count as bystander
# edits of one another.
# NOTE(review): presumably these three sites are edited together on purpose -
# confirm with the experiment design.
_LINKED_TARGET_LOCI = (
    "chr5:134138057:A/G",
    "chr5:134138060:A/G",
    "chr5:134138062:A/G",
)


def _drop_linked_loci(target_loc, nearby_variants):
    """Return nearby_variants without the other members of the target's linked group."""
    if target_loc not in _LINKED_TARGET_LOCI:
        return nearby_variants
    return [
        variant
        for variant in nearby_variants
        if variant == target_loc or variant not in _LINKED_TARGET_LOCI
    ]


def call_single_loci_cells(
    target_loc, # the loci of interest
    adata_loom, # the loom matrix to retrieve info
    window_size, # the window size 
    germline_amplicons, # The germline mutations
    snp_name, # the name of the loci
    plot_path, # the path to save the figures
):
    """
    Process a single target loci.

    Cells whose genotype at ``target_loc`` is 1 (heterozygous) or 2
    (homozygous) are each split into "pure" cells and cells that also carry a
    bystander edit at a nearby, non-germline variant. A bar chart and a heat
    map are saved for both groups through ``plot_mutant_types``.

    When the target belongs to the linked chr5 group, the other members of
    that group are excluded from the nearby variants; the target itself is
    always kept.

    Returns:
        tuple of four lists of cell barcodes, in this order:
        het_bystander, hom_bystander, het_pure, hom_pure.
    """
    # Column of the target locus in the loom matrix
    idx = adata_loom.var_names.get_loc(target_loc)
    target_genotype = adata_loom.X[:, idx]
    het_cells, hom_cells = [
        adata_loom.obs_names[np.flatnonzero(target_genotype == code)] for code in (1, 2)
    ]

    # Find all nearby variants of the given locus
    nearby_variants = get_nearby_variants(
        adata_loom.var_names, target_loc, germline_amplicons, window_size
    )
    nearby_variants = _drop_linked_loci(target_loc, nearby_variants)

    het_pure, het_bystander, het_matrix = process_bystander_cells(
        target_loc, het_cells, nearby_variants, adata_loom
    )
    hom_pure, hom_bystander, hom_matrix = process_bystander_cells(
        target_loc, hom_cells, nearby_variants, adata_loom
    )
    plot_mutant_types(target_loc, het_matrix, "Heterozygous", snp_name, plot_path, "1-0")
    plot_mutant_types(target_loc, hom_matrix, "Homozygous", snp_name, plot_path, "1-0")

    return list(het_bystander), list(hom_bystander), list(het_pure), list(hom_pure)

    # adata_hypr.loc[hom_pure, "genotype"] = f"{target_loc}_homo_pure"
    # adata_hypr.loc[het_pure, "genotype"] = f"{target_loc}_hete_pure"


def annotate_cells(adata):
    """
    Annotates each cell in adata.obs based on the SNP, control, and del columns.

    SNP columns are all columns of ``adata.obs`` whose name contains one of
    "het_bystander", "het_pure", "hom_bystander" or "hom_pure". With ``n`` the
    number of SNP columns that are true for a cell, the label is:

    - "mixed": ``n`` > 1, or ``n`` >= 1 together with AAV_control or AAV_del
    - "AAV_control": ``n`` == 0 and AAV_control is true (even if AAV_del is too)
    - "AAV_del": ``n`` == 0, AAV_del is true and AAV_control is false
    - the name of the single true SNP column: ``n`` == 1 and no AAV flag
    - "unedited": nothing is true

    Parameters:
    - adata: object whose ``obs`` DataFrame holds the boolean SNP columns plus
      the "AAV_control" and "AAV_del" columns.

    Returns:
    - None. The labels are stored in a new column 'genotype_annotation' in adata.obs.
    """
    obs = adata.obs
    snp_markers = ("het_bystander", "het_pure", "hom_bystander", "hom_pure")
    snp_columns = [col for col in obs.columns if any(sub in col for sub in snp_markers)]

    snp_flags = obs[snp_columns].to_numpy(dtype=bool)
    snp_active = snp_flags.sum(axis=1)  # number of true SNP columns per cell
    control = obs["AAV_control"].to_numpy(dtype=bool)
    deletion = obs["AAV_del"].to_numpy(dtype=bool)
    no_snp = snp_active == 0

    labels = np.full(len(obs), "unedited", dtype=object)
    if snp_columns:
        # argmax returns the position of the first True in each row
        first_active = np.asarray(snp_columns, dtype=object)[snp_flags.argmax(axis=1)]
        single_snp = snp_active == 1
        labels[single_snp] = first_active[single_snp]
    # Later assignments override earlier ones: control wins over del, and
    # "mixed" wins over everything.
    labels[no_snp & deletion] = "AAV_del"
    labels[no_snp & control] = "AAV_control"
    labels[(snp_active > 1) | (~no_snp & (control | deletion))] = "mixed"

    adata.obs['genotype_annotation'] = labels


def annotate_genotype(adata_loom, 
                    adata_hypr, 
                    config_path,
                    save_path="./",
                    plot_path="./"):
    """
    Annotate adata_hypr.obs with the genotype information found in adata_loom.

    After the call, ``adata_hypr.obs`` contains:
    - boolean columns "AAV_control" and "AAV_del"
    - for every SNP section of the config file, boolean columns
      "{snp}_het_bystander", "{snp}_het_pure", "{snp}_hom_bystander" and
      "{snp}_hom_pure"
    - "genotype_annotation": the per-cell label produced by ``annotate_cells``
      (a SNP column name, "mixed", "AAV_control", "AAV_del" or "unedited")

    Every config section other than "AAV_control" and "AAV_del" is treated as
    a SNP and must define ``chrom_name``, ``locus`` and ``variant_alleles``.
    The rows of adata_hypr and adata_loom are expected to be the same barcodes.

    Args:
        adata_loom: genotype object (cells x variants).
        adata_hypr: phenotype object; its ``obs`` is modified in place.
        config_path (str): path to the INI configuration file.
        save_path (str): accepted for interface compatibility; not used here.
        plot_path (str): directory that receives the per-SNP figures.
    Returns:
        tuple: ``(adata_hypr, adata_loom)``.
    Raises:
        ValueError: if a SNP's alleles cannot be complemented, or if neither
            the configured variant nor its complement exists in the loom file.
    """
    # Step 1, determine the germline mutations
    germline_amplicons = call_germline_mutations(adata_loom, cutoff=0.25)

    # Step 2, annotate the AAV_control cells
    # 2.1 Find the control editing loci
    AAV_control_editing = call_AAV_control_edits(adata_loom, config_path=config_path)
    # 2.2 Add a column to adata_hypr.obs and perform the annotation only here
    adata_hypr.obs["AAV_control"] = False
    AAV_control_cell_idx = call_AAV_cells(adata_hypr, adata_loom, AAV_editing=AAV_control_editing)
    adata_hypr.obs.iloc[AAV_control_cell_idx, adata_hypr.obs.columns.get_loc("AAV_control")] = True

    # Step 3, annotate the AAV_del cells
    # 3.1 Find the del editing loci
    AAV_del_editing = call_AAV_del_edits(adata_loom, config_path=config_path)
    # 3.2 Add a column to adata_hypr.obs and perform the annotation only here
    adata_hypr.obs["AAV_del"] = False
    AAV_del_cell_idx = call_AAV_cells(adata_hypr, adata_loom, AAV_editing=AAV_del_editing)
    adata_hypr.obs.iloc[AAV_del_cell_idx, adata_hypr.obs.columns.get_loc("AAV_del")] = True

    # Step 4, enumerate all other variants we are interested in: every config
    # section that is not one of the two AAV sections is a SNP.
    config = configparser.ConfigParser()
    config.read(config_path)
    exclude_sections = ["AAV_control", "AAV_del"]
    SNPs = [section for section in config.sections() if section not in exclude_sections]

    annotation = ["het_bystander", "het_pure", "hom_bystander", "hom_pure"]
    complement_rule = {'C': 'G', 'G': 'C', 'A': 'T', 'T': 'A'}

    for snp in SNPs:
        chrom_name = config.get(snp, 'chrom_name')
        locus = config.getint(snp, 'locus')
        alleles = config.get(snp, 'variant_alleles')
        try:
            complement_variant_alleles = complement_rule[alleles[0]] + "/" + complement_rule[alleles[-1]]
        except (KeyError, IndexError) as err:
            raise ValueError(f"The variant alleles is not supported. Current variant alleles is {alleles}") from err

        target_loc = ":".join([chrom_name, str(locus), alleles])
        complement_target_loc = ":".join([chrom_name, str(locus), complement_variant_alleles])

        # Use whichever spelling (as configured, or complemented) the loom file has
        if target_loc not in adata_loom.var_names:
            if complement_target_loc not in adata_loom.var_names:
                raise ValueError(f"Both the mutation {target_loc} and the reversed one {complement_target_loc} is not found in the loom file")
            target_loc = complement_target_loc

        # Step 5, split the edited cells into pure and bystander groups
        het_bystander, hom_bystander, het_pure, hom_pure = call_single_loci_cells(
            target_loc, adata_loom, window_size=10,
            germline_amplicons=germline_amplicons, snp_name=snp, plot_path=plot_path)

        # The order of this list must match the order of `annotation`
        for label, barcodes in zip(annotation, [het_bystander, het_pure, hom_bystander, hom_pure]):
            adata_hypr.obs[f"{snp}_{label}"] = False
            adata_hypr.obs.loc[barcodes, f"{snp}_{label}"] = True

    # Step 6, derive the final per-cell label, marking co-edited cells as mixed
    annotate_cells(adata_hypr)

    return adata_hypr, adata_loom
    



def main():
    """
    Command-line entry point: filter, intersect and annotate the two modalities.

    Reads the loom file and the Hypr-seq AnnData, filters and aggregates the
    Hypr-seq matrix, keeps the barcodes shared by both, annotates the
    genotypes and writes two files into ``--save_path``:
    ``Annotated_phenotype_hypr_seq.h5ad`` (or ``..._probe.h5ad`` when
    ``--probe_level`` is 1) and ``genotype_loom.h5ad``.
    """
    # Create the parser
    parser = argparse.ArgumentParser(description="Filter, Intersect and Annotate \
     the genotype information to the phenotype RNA-seq count matrix")

    # Add arguments for file paths
    parser.add_argument('--loom_path', type=str, required=True, help='The path to the loom file -- genotype')
    parser.add_argument('--hypr_path', type=str, required=True, help='The path to the hypr file -- phenotype')
    parser.add_argument('--config_path', type=str, required=True, help='The path to the variant configuration file')
    parser.add_argument('--save_path', type=str, required=True, help='The path to save annotated anndata')
    parser.add_argument('--plot_path', type=str, required=True, help='The path to save genotype clustermap figure')
    parser.add_argument('--probe_level', type=int, required=True, help="1 for probe level matrix, other numbers for gene_level")

    args = parser.parse_args()

    config_path = args.config_path
    save_path = args.save_path
    plot_path = args.plot_path
    probe_level = args.probe_level == 1

    # Read the data
    adata_loom = read_tapestri_loom(args.loom_path)
    adata_hypr = sc.read(args.hypr_path)
    adata_hypr = filter_hypr(adata_hypr, probe_level=probe_level)

    # Create the output directories, including any intermediate directories;
    # exist_ok avoids the race between checking and creating.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(plot_path, exist_ok=True)

    # Filter and intersect between hypr matrix and loom matrix
    adata_hypr, adata_loom = find_intersecting_and_filter(adata_hypr, adata_loom)

    # Annotate the filtered data
    adata_hypr, adata_loom = annotate_genotype(adata_loom=adata_loom, adata_hypr=adata_hypr, config_path=config_path, save_path=save_path, plot_path=plot_path)

    if probe_level:
        hypr_name = os.path.join(save_path, "Annotated_phenotype_hypr_seq_probe.h5ad")
    else:
        hypr_name = os.path.join(save_path, "Annotated_phenotype_hypr_seq.h5ad")
    sc.write(hypr_name, adata_hypr)
    sc.write(os.path.join(save_path, "genotype_loom.h5ad"), adata_loom)



if __name__ == "__main__":
    # Run the pipeline only when the file is executed as a script.
    # Example usage (all six options are required by the argument parser):
    # python merge_and_annotate.py \
    # --loom_path="data/TCF7_CD4_8/TCF7_CD4-8.cells.loom" \
    # --hypr_path="data/TCF7_CD4_8/TCF7_CD4-8_hypr.h5ad" \
    # --config_path="TCF7.ini" --save_path="result/TCF7_CD4_8/" \
    # --plot_path="analysis_result/genotype_by_loc/CD4_8" \
    # --probe_level=0
    main()

In [None]:
# Functions for conducting the DEG analysis and plotting figures.
# They should generalize to any dataset that was generated by
# "merge_and_annotate.py".

import scanpy as sc
import matplotlib.pyplot as plt

import numpy as np 
import pandas as pd 
import seaborn as sns


def compute_marker_genes_and_plot(adata, condition_1, condition_2, task_name=""):
    """
    Run a Wilcoxon rank test between two genotype groups and tabulate condition_1.

    Only cells whose ``genotype_annotation`` equals ``condition_1`` or
    ``condition_2`` are kept. ``condition_2`` should be the control case;
    otherwise the saved values may not be reliable.

    Parameters:
    - adata: AnnData object with a "genotype_annotation" column in ``obs``.
    - condition_1: genotype label whose statistics are reported.
    - condition_2: genotype label of the comparison (control) group.
    - task_name: prefix used in the output column names.

    Returns:
    - pd.DataFrame indexed by "gene_name" (sorted), with the columns
      "{task_name}_{condition_1}_VS_{condition_2}_log_fc" and
      "{task_name}_{condition_1}_VS_{condition_2}_p_val" (adjusted p-values).

    Raises:
    - ValueError: if either condition has no cells in ``adata``.
    """
    genotype = adata.obs["genotype_annotation"]
    keep = (genotype == condition_1) | (genotype == condition_2)

    present = set(genotype[keep].unique())
    missing = [cond for cond in (condition_1, condition_2) if cond not in present]
    if missing:
        raise ValueError(f"No cells found for genotype annotation(s): {missing}")

    # Copy so that the assignment below writes to a real object, not to a view
    filtered_adata = adata[keep].copy()
    filtered_adata.obs['genotype_annotation'] = (
        filtered_adata.obs['genotype_annotation']
        .astype('category')
        .cat.remove_unused_categories()
    )

    sc.tl.rank_genes_groups(filtered_adata, "genotype_annotation", method="wilcoxon")

    ranking = filtered_adata.uns["rank_genes_groups"]
    p_val = ranking["pvals_adj"][condition_1]
    log_fc = ranking["logfoldchanges"][condition_1]
    gene_names = ranking["names"][condition_1]

    column_prefix = f'{task_name}_{condition_1}_VS_{condition_2}'
    df = pd.DataFrame({
        'gene_name': gene_names,
        f'{column_prefix}_log_fc': log_fc,
        f'{column_prefix}_p_val': p_val
    })
    # Sorting by gene makes tables from different comparisons line up
    df = df.sort_values("gene_name")
    df = df.set_index("gene_name")

    return df



def plot_genotype_gene_clustermap(adata, genes_of_interest, genotypes_of_interest, task_name="", scale=True):
    """
    Plots a clustermap of average gene expression values for each genotype across a subset of genes.

    The figure is saved as ``heatmap.svg`` and ``heatmap.png``, and the plotted
    table as ``heatmap.csv``, inside
    ``analysis_result/genotype_gene_heatmap/{task_name}`` (created if needed).

    Parameters:
    - adata: AnnData object containing the single-cell RNA-seq data; ``X`` may
      be dense or scipy sparse.
    - genes_of_interest: List of gene names (subset of adata.var_names) to include in the plot.
    - genotypes_of_interest: List of genotypes (subset of adata.obs["genotype_annotation"]) to include.
    - task_name: Label used in the title and in the output directory name.
    - scale: Boolean, if True, each gene is standardised across genotypes
      (mean subtracted, divided by the standard deviation).

    Returns:
    - None. The figure is left open after saving.
    """
    import os

    # Filter cells by genotypes of interest
    adata_subset = adata[adata.obs["genotype_annotation"].isin(genotypes_of_interest), genes_of_interest]

    # X may be sparse or dense; only a sparse matrix offers toarray()
    expression = adata_subset.X
    if hasattr(expression, "toarray"):
        expression = expression.toarray()
    gene_expression_df = pd.DataFrame(np.asarray(expression),
                                      index=adata_subset.obs_names,
                                      columns=adata_subset.var_names)
    gene_expression_df["genotype"] = adata_subset.obs["genotype_annotation"].values

    # Average per genotype; observed=True keeps categories without cells from
    # becoming all-NaN rows.
    avg_expression_df = gene_expression_df.groupby("genotype", observed=True).mean()

    # Optionally scale the data for visualization
    if scale:
        avg_expression_df = (avg_expression_df - avg_expression_df.mean()) / avg_expression_df.std()

    # Plot the clustermap
    sns.clustermap(avg_expression_df,
                   cmap="vlag",
                   row_cluster=True,
                   col_cluster=True,
                   xticklabels=True,
                   yticklabels=True,
                   figsize=(18, 10))

    output_dir = f"analysis_result/genotype_gene_heatmap/{task_name}"
    os.makedirs(output_dir, exist_ok=True)
    plt.title(f"Gene Expression Clustermap -- {task_name}")
    plt.tight_layout()

    plt.savefig(f"{output_dir}/heatmap.svg")
    plt.savefig(f"{output_dir}/heatmap.png")
    avg_expression_df.to_csv(f"{output_dir}/heatmap.csv")



def plot_genotype_cluster_composition(adata, genotype_column='genotype_annotation', cluster_column='leiden', save_path="./", task_name=""):
    """
    Draw, for every genotype, the share of cells falling into each cluster.

    The result is a stacked bar chart (one bar per genotype, one segment per
    cluster) saved as ``{save_path}/{task_name}_composition.svg`` and ``.png``.

    Parameters:
    - adata: AnnData object containing the data.
    - genotype_column: str, the column in adata.obs holding the genotype labels.
    - cluster_column: str, the column in adata.obs holding the cluster labels.
    - save_path: str, directory that receives the two image files.
    - task_name: str, prefix of the output file names.
    """
    labels = adata.obs[[genotype_column, cluster_column]]

    # Cells per (genotype, cluster) pair, as a genotype x cluster table
    counts = labels.groupby([genotype_column, cluster_column]).size().unstack(fill_value=0)

    # Turn each genotype's row into proportions that sum to one
    cells_per_genotype = counts.sum(axis=1)
    proportions = counts.div(cells_per_genotype, axis=0)

    proportions.plot(kind='bar', stacked=True, figsize=(10, 7), colormap="tab20")

    plt.title('Leiden Cluster Composition by Genotype')
    plt.xlabel('Genotype')
    plt.ylabel('Proportion of Each Cluster')
    plt.legend(title=cluster_column, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()

    output_stem = f"{save_path}/{task_name}_composition"
    for extension in ("svg", "png"):
        plt.savefig(f"{output_stem}.{extension}")



if __name__ == "__main__":
    # Ad-hoc smoke test for the functions in this file. It loads an annotated
    # AnnData from a hard-coded absolute path, so it only runs on a machine
    # where that file exists.
    adata = sc.read("/home/hanwenka/project/stag_analysis/result/TCF7_CD4_8/Annotated_phenotype_hypr_seq.h5ad")

    # Compare one SNP genotype against the AAV control and print the table.
    df = compute_marker_genes_and_plot(adata, "TCF7_SPA6_hom_bystander", "AAV_control", task_name="all_cell")
    print(df)


In [None]:
# This code should only work for the CD4/8 dataset.

import scanpy as sc
import matplotlib.pyplot as plt

import numpy as np 
import pandas as pd 

from analysis_utils import compute_marker_genes_and_plot, plot_genotype_gene_clustermap, plot_genotype_cluster_composition
import seaborn as sns
import os
from pathlib import Path
from scipy.stats import zscore

# Do the pairwise deg analysis for the cd4/8 dataset

def compute_pairwise_degs_cd4_8(adata, output_dir="analysis_result/deg_tables"):
    """
    Run the pairwise DEG comparison for every edited genotype against every
    control genotype, separately for CD4 and CD8 cells.

    Each comparison is delegated to ``compute_marker_genes_and_plot``; the
    per-comparison results are concatenated column-wise and written to
    ``{output_dir}/cd4_genotype_pairwise_deg.csv`` and
    ``{output_dir}/cd8_genotype_pairwise_deg.csv``.

    Parameters:
    - adata: AnnData object whose ``adata.obs`` has a "cell_type" column
      holding the labels "CD4" and "CD8".
    - output_dir: str, directory the two CSV tables are written to (created if
      missing). Defaults to the original hard-coded location.
    """
    # The input adata should be annotated with cd4/8 celltypes in the adata.obs
    assert "cell_type" in adata.obs.columns, "adata.obs must contain a 'cell_type' column"

    # Create the output directory up front so a missing folder is caught
    # before the (long) series of comparisons rather than after it.
    os.makedirs(output_dir, exist_ok=True)

    cells_by_type = {
        "CD4": adata[adata.obs["cell_type"] == "CD4"],
        "CD8": adata[adata.obs["cell_type"] == "CD8"],
    }

    edited_conditions = [
        f"{prefix}_{suffix}"
        for prefix in ['TCF7_rs12659489', 'TCF7_rs57918439', 'TCF7_SPA4', 'TCF7_SPA6']
        for suffix in ["het_bystander", "het_pure", "hom_bystander", "hom_pure"]
    ]
    control_conditions = ["AAV_control", "unedited", "AAV_del"]

    results_by_type = {cell_type: [] for cell_type in cells_by_type}
    for cond1 in edited_conditions:
        for cond2 in control_conditions:
            # CD4 is processed before CD8 for every pair, as in the original loop.
            for cell_type, cells in cells_by_type.items():
                result_df = compute_marker_genes_and_plot(cells, condition_1=cond1, condition_2=cond2, task_name=cell_type)
                results_by_type[cell_type].append(result_df)

    # One wide table per cell type: the comparisons sit side by side as column groups.
    final_tables = {
        cell_type: pd.concat(frames, axis=1)
        for cell_type, frames in results_by_type.items()
    }
    for cell_type, table in final_tables.items():
        table.to_csv(os.path.join(output_dir, f"{cell_type.lower()}_genotype_pairwise_deg.csv"))



def compute_leiden_cluster_degs(filtered_adata, cluster_names=("0", "1", "2", "3", "4", "5"), output_path="analysis_result/deg_tables/cd4_8_cluster_degs.csv"):
    """
    Rank genes per leiden cluster (Wilcoxon) and save one wide DEG table.

    ``sc.tl.rank_genes_groups`` is run on the "leiden" grouping and stores its
    result in ``filtered_adata.uns["rank_genes_groups"]``. For each requested
    cluster the log fold changes and adjusted p-values are collected into the
    columns ``{cluster}_log_fc`` and ``{cluster}_p_val``, indexed by gene name,
    and all clusters are joined side by side.

    Parameters:
    - filtered_adata: AnnData object with a "leiden" column in ``.obs``;
      modified in place by the ranking step.
    - cluster_names: iterable of cluster labels to export. Defaults to the
      original six clusters; pass ``None`` to export every cluster present in
      the ranking result.
    - output_path: str, CSV file to write (parent directory is created if missing).

    Raises:
    - ValueError: if ``cluster_names`` is empty.
    """
    # Make sure the destination exists before doing the expensive ranking.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    sc.tl.rank_genes_groups(filtered_adata, "leiden", method="wilcoxon")
    ranked = filtered_adata.uns["rank_genes_groups"]

    if cluster_names is None:
        # The result arrays are structured arrays with one field per group.
        cluster_names = ranked["names"].dtype.names

    per_cluster = []
    for cluster_name in cluster_names:
        cluster_df = pd.DataFrame({
            'gene_name': ranked["names"][cluster_name],
            f'{cluster_name}_log_fc': ranked["logfoldchanges"][cluster_name],
            # Note: the "_p_val" column holds the *adjusted* p-values.
            f'{cluster_name}_p_val': ranked["pvals_adj"][cluster_name],
        })
        # Each cluster lists genes in its own rank order; sort by gene name so
        # the tables line up on a common index when concatenated.
        per_cluster.append(cluster_df.sort_values("gene_name").set_index("gene_name"))

    if not per_cluster:
        raise ValueError("cluster_names must contain at least one cluster")

    pd.concat(per_cluster, axis=1).to_csv(output_path)






def gene_genotype_clustermap(adata, gene_list_path):
    """
    Draw gene-by-genotype cluster maps for all cells, for CD4 cells and for CD8 cells.

    ``gene_list_path`` is read as a header-less, single-column CSV of gene
    names. The genotypes shown are "AAV_control" plus every combination of the
    four TCF7 edits with the four zygosity/purity suffixes. ``adata.obs`` must
    contain a "cell_type" column.
    """
    # Genes of interest: one name per row, no header line.
    gene_table = pd.read_csv(gene_list_path, header=None)
    gene_table.columns = ["gene"]
    genes = gene_table["gene"]

    # Genotypes of interest: the control first, then every edit/suffix combination.
    edits = ('TCF7_rs12659489', 'TCF7_rs57918439', 'TCF7_SPA4', 'TCF7_SPA6')
    suffixes = ("het_bystander", "het_pure", "hom_bystander", "hom_pure")
    genotypes = ["AAV_control"] + [f"{edit}_{suffix}" for edit in edits for suffix in suffixes]

    assert "cell_type" in adata.obs.columns
    is_cd4 = adata.obs["cell_type"] == "CD4"
    is_cd8 = adata.obs["cell_type"] == "CD8"

    # (cells, task name) for each figure; a "no_scale" variant of the first
    # panel (scale=False) existed in the past but is not produced here.
    panels = (
        (adata, "cd4-8_scale"),
        (adata[is_cd4], "cd4_scale"),
        (adata[is_cd8], "cd8_scale"),
    )
    for cells, label in panels:
        plot_genotype_gene_clustermap(cells, genes_of_interest=genes, genotypes_of_interest=genotypes, task_name=label)


def plot_umap_figure(adata):
    """
    Plot the UMAP of the saved adata and store the images under
    ``analysis_result/umap/CD4_8_umap/``.

    Three figures are saved: clusters ("leiden"), the CD4/CD8A/CD8B markers,
    and TCF7/SELL/CD44. scanpy derives the file names from the ``save``
    suffixes passed below.

    The figures are written straight into the target directory by temporarily
    pointing scanpy's figure directory at it. The previous approach moved the
    default ``figures/`` folder afterwards, which put the files in a location
    that depended on which directories already existed and failed on a second
    run once the destination was present.
    """
    output_dir = Path("analysis_result/umap/CD4_8_umap")
    output_dir.mkdir(parents=True, exist_ok=True)

    previous_figdir = sc.settings.figdir
    sc.settings.figdir = output_dir
    try:
        sc.pl.umap(adata, color="leiden", save="_leiden.png", palette="Set1")
        sc.pl.umap(adata, color=["CD4", "CD8A", "CD8B"], save="_cd4_8A_8B.png", palette="Set1")
        sc.pl.umap(adata, color=["TCF7", "SELL", "CD44"], save="_TCF7_SELL_CD44.png", palette="Set1")
    finally:
        # Restore the global setting so later scanpy plots are unaffected.
        sc.settings.figdir = previous_figdir


def plot_composition_figure(adata):
    """
    Save the cluster-composition bar charts for the CD4 and the CD8 cells.

    Cells are split on ``adata.obs["cell_type"]`` and each subset is passed to
    ``plot_genotype_cluster_composition`` with the output directory
    ``analysis_result/composition_figures``.
    """
    # Build both subsets first; the task name is the lower-case cell type.
    subsets = [
        ("cd4", adata[adata.obs["cell_type"] == "CD4"]),
        ("cd8", adata[adata.obs["cell_type"] == "CD8"]),
    ]

    save_path = "analysis_result/composition_figures"
    # Create the directory (and any missing parents) on first use.
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    for task, cells in subsets:
        plot_genotype_cluster_composition(
            cells,
            genotype_column="genotype_annotation",
            cluster_column="leiden",
            save_path=f"{save_path}",
            task_name=task,
        )



def plot_cluster_by_gene_heatmap(adata):
    """
    Plot the cluster-by-gene matrix plot for the CD8 cells and save the
    underlying matrices for other plots.

    Outputs, all under ``analysis_result/cluster_by_gene_heatmap``:
    - ``matrixplot_cd8.png``: scanpy matrix plot, each gene scaled to [0, 1].
    - ``cd8_expression_matrix.csv``: mean expression per leiden cluster and gene.
    - ``cd8_expression_matrix_norm.csv``: the same matrix min-max scaled per
      gene (genes that are constant across clusters become 0).

    Parameters:
    - adata: AnnData object with "cell_type" and "leiden" columns in ``.obs``.
    """
    cd8_cell = adata[adata.obs["cell_type"]=="CD8"]
    proliferating = ["RRM2", "SLC29A1", "CCNB1", "CDK2", "MYBL2", "H2AX", "LMNB1", "AURKA"]
    naive = ["PECAM1", "CISH", "HAVCR2", "BACH2", "NELL2", "TCF7", "LEF1", "CCR7", "SELL"]
    effector = ["NKG7", "KLRD1", "GZMH", "GZMB", "GZMA", "CCR5", "IFNG", "TBX21", "TIGIT", "BHLHE40"]

    # Keep only the marker genes that are present in the dataset
    gene_list = [gene for gene in proliferating + naive + effector if gene in adata.var_names]

    output_dir = Path("analysis_result/cluster_by_gene_heatmap")
    output_dir.mkdir(parents=True, exist_ok=True)
    # Use the public settings object; assigning on the private ScanpyConfig
    # class would overwrite the attribute on the class itself.
    sc.settings.figdir = output_dir

    grouped_means = cd8_cell[:, gene_list].to_df().groupby(cd8_cell.obs['leiden']).mean()

    # return_fig=True only builds the plot object; it has to be saved explicitly.
    matrix_plot = sc.pl.matrixplot(cd8_cell, var_names=gene_list, groupby="leiden", cmap="Blues", swap_axes=False, dendrogram=False, standard_scale='var', return_fig=True)
    matrix_plot.savefig(str(output_dir / "matrixplot_cd8.png"))

    # Reproduce the plot's per-gene scaling on the saved matrix
    scaled_means = grouped_means.copy()
    scaled_means -= scaled_means.min(axis=0)  # Subtract the column-wise minimum
    scaled_means = (scaled_means / scaled_means.max(axis=0)).fillna(0)  # Divide by the column-wise maximum and handle NaNs

    grouped_means.to_csv(output_dir / "cd8_expression_matrix.csv")
    scaled_means.to_csv(output_dir / "cd8_expression_matrix_norm.csv")




def probe_level_deg():
    """
    Run the CD4/8 pairwise DEG analysis on the probe-level dataset.

    The probe-level object is loaded from disk, normalised to 1e4 counts per
    cell and log1p-transformed, and then handed to
    ``compute_pairwise_degs_cd4_8``, which writes its tables to fixed CSV paths.
    """
    # Cell types of the probe-level object come directly from the gene-level clusters.
    probe_adata = sc.read("result/TCF7_CD4_8/Annotated_phenotype_hypr_seq_probe.h5ad")

    # Normalisation and log transform are applied here for the probe data only.
    sc.pp.normalize_total(probe_adata, target_sum=1e4)
    sc.pp.log1p(probe_adata)

    genotypes = probe_adata.obs["genotype_annotation"]
    probe_adata.obs['genotype_annotation'] = genotypes.astype('category')

    compute_pairwise_degs_cd4_8(probe_adata)


def get_umi_count_per_gene_cell_type_cluster():
    """
    Save the mean expression per gene for every leiden cluster and every cell type.

    The expression values come from the first file read below (presumably the
    un-normalised UMI counts; confirm against how that file was produced), while
    the "leiden" and "cell_type" labels are copied from the annotated object.
    Two tables are written to ``analysis_result/mean_umi_clusters_cd_4_8``:
    ``leiden_expression_matrix.csv`` and ``cell_type_expression_matrix.csv``.
    """
    adata_raw = sc.read("/home/hanwenka/project/stag_analysis/result/TCF7_CD4_8/Annotated_phenotype_hypr_seq.h5ad")
    adata_cell_type = sc.read("/home/hanwenka/project/stag_analysis/result/TCF7_CD4_8/annotated_genotype_celltype.h5ad")

    # Labels are aligned on the cell index; cells without a label are left out
    # of the group means.
    adata_raw.obs["leiden"] = adata_cell_type.obs["leiden"]
    adata_raw.obs["cell_type"] = adata_cell_type.obs["cell_type"]

    # Convert to a DataFrame once and reuse it for both groupings.
    expression_df = adata_raw.to_df()
    grouped_means_leiden = expression_df.groupby(adata_raw.obs['leiden']).mean()
    grouped_means_cell_type = expression_df.groupby(adata_raw.obs['cell_type']).mean()

    output_dir = Path("analysis_result/mean_umi_clusters_cd_4_8")
    output_dir.mkdir(parents=True, exist_ok=True)
    grouped_means_leiden.to_csv(output_dir / "leiden_expression_matrix.csv")
    grouped_means_cell_type.to_csv(output_dir / "cell_type_expression_matrix.csv")



if __name__ == "__main__":
    # Load the annotated CD4/8 object once, then run the enabled analysis step.
    annotated_adata = sc.read("result/TCF7_CD4_8/annotated_genotype_celltype.h5ad")

    # Steps that are currently switched off; uncomment a line to run it.
    # compute_pairwise_degs_cd4_8(annotated_adata)
    # compute_leiden_cluster_degs(annotated_adata)
    # gene_genotype_clustermap(annotated_adata, "/home/hanwenka/project/stag_analysis/internal_files/cd8_heatmap_genes.csv")
    # plot_umap_figure(annotated_adata)
    # plot_composition_figure(annotated_adata)
    # get_umi_count_per_gene_cell_type_cluster()

    plot_cluster_by_gene_heatmap(annotated_adata)
