In [None]:
def test_isoform1_DIU_between_alleles(adata, layer="unique_counts", test_condition="control", inplace=True):
    """
    Test whether two alleles of a syntelog differ in their isoform-1 usage.

    Transcripts are grouped by ``adata.var['Synt_id']``. Within each syntelog,
    the transcripts whose ID contains ``".1."`` are taken as the "isoform 1"
    representative of an allele. For the first two such transcripts, the
    per-sample isoform-1 counts are compared against the per-sample total
    counts of all transcripts that share the same haplotype *and* the same
    syntelog, using ``isotools._transcriptome_stats.betabinom_lr_test``.
    P-values are corrected across syntelogs with Benjamini-Hochberg.

    Parameters
    -----------
    adata : AnnData
        AnnData object containing expression data. ``adata.var`` must provide
        the columns 'Synt_id' and 'haplotype'; ``adata.obs`` must provide
        'condition' unless ``test_condition == "all"``.
    layer : str, optional
        Layer containing count data (default: "unique_counts")
    test_condition : str, optional
        Value of ``adata.obs['condition']`` selecting the samples to test
        within, or "all" to use every sample (default: "control")
    inplace : bool, optional
        Whether to modify the input AnnData object or a copy (default: True)

    Returns
    --------
    pd.DataFrame or AnnData
        If inplace=True, the results DataFrame (one row per tested allele,
        sorted by p-value). If inplace=False, the modified copy of the
        AnnData object. In both cases results are stored in:
        - adata.uns['isoform_ratio_test']: Complete test results as DataFrame
        - adata.var['isoform_ratio_pval']: P-values for each tested allele
        - adata.var['isoform_ratio_FDR']: FDR-corrected p-values
        - adata.var['isoform_ratio_difference']: Absolute difference of the
          two alleles' mean isoform-1 ratios
        - adata.var['isoform_ratio_mean_<test_condition>']: Mean isoform-1
          ratio of the allele
        Variables that were not tested hold NaN in the ``adata.var`` columns.

    Raises
    -------
    ValueError
        If ``adata`` is not an AnnData object, the layer or a required
        ``adata.var`` column is missing, there are no variables, or
        ``test_condition`` is not present in ``adata.obs['condition']``.

    Notes
    ------
    Only the first two isoform-1 transcripts of a syntelog are tested, because
    the likelihood-ratio test is called with exactly two groups. Syntelogs for
    which the test raises are reported on stdout and skipped.
    """
    import re

    import numpy as np
    import pandas as pd
    from anndata import AnnData
    from isotools._transcriptome_stats import betabinom_lr_test
    from statsmodels.stats.multitest import multipletests

    # Validate everything up front, before any (potentially expensive) copy.
    if not isinstance(adata, AnnData):
        raise ValueError("Input adata must be an AnnData object")
    if layer not in adata.layers:
        raise ValueError(f"Layer '{layer}' not found in AnnData object")
    if "Synt_id" not in adata.var:
        raise ValueError("'Synt_id' not found in adata.var")
    if "haplotype" not in adata.var:
        raise ValueError("'haplotype' not found in adata.var")
    if adata.n_vars == 0:
        raise ValueError("'transcript_id' not found in adata.var_names")
    # The 'condition' column is only needed when a specific condition is requested.
    if test_condition != "all" and test_condition not in adata.obs['condition'].unique():
        raise ValueError(f"Condition '{test_condition}' not found in adata.obs['condition']")

    # Work on a copy if not inplace
    if not inplace:
        adata = adata.copy()

    def dense(block):
        # Layers may be scipy sparse matrices; the test needs plain arrays.
        return block.toarray() if hasattr(block, "toarray") else np.asarray(block)

    def allele_label(hap, fallback):
        # Extract the number from e.g. "hap2"; fall back to the 1-based slot.
        match = hap_pattern.search(str(hap))
        return match.group(1) if match else str(fallback)

    hap_pattern = re.compile(r'hap(\d+)')

    # Restrict the count matrix to the requested samples once (loop invariant).
    counts = adata.layers[layer]
    if test_condition == "all":
        cond_counts = counts
    else:
        sample_idx = np.flatnonzero(np.asarray(adata.obs['condition'] == test_condition))
        cond_counts = counts[sample_idx]

    transcript_ids = adata.var_names
    synt_values = np.asarray(adata.var["Synt_id"])
    hap_values = np.asarray(adata.var["haplotype"])
    is_isoform1 = np.fromiter(
        (".1." in str(t) for t in transcript_ids), dtype=bool, count=adata.n_vars
    )

    # Positional indices of the variables of each syntelog, computed in one
    # pass instead of rescanning all Synt_ids for every syntelog.
    members_by_synt = pd.Series(np.arange(adata.n_vars)).groupby(synt_values, sort=False).indices
    total_syntelogs = len(members_by_synt)

    # One entry per syntelog for which the test ran.
    tests = []
    for processed, (synt_id, member_idx) in enumerate(members_by_synt.items(), start=1):
        if processed % 100 == 0:
            print(f"Processing syntelog {processed}/{total_syntelogs}")

        # Isoform-1 transcripts of this syntelog; two are needed for a comparison.
        iso1_idx = member_idx[is_isoform1[member_idx]]
        if len(iso1_idx) < 2:
            continue
        tested = iso1_idx[:2]

        member_haps = hap_values[member_idx]
        isoform_counts = []
        allele_totals = []
        for pos in tested:
            # Denominator: all isoforms of this allele within this syntelog only.
            same_allele = member_idx[member_haps == hap_values[pos]]
            isoform_counts.append(dense(cond_counts[:, [pos]]).ravel())
            allele_totals.append(dense(cond_counts[:, same_allele]).sum(axis=1))

        # Best effort per syntelog: a failing test must not abort the whole run.
        try:
            test_result = betabinom_lr_test(isoform_counts, allele_totals)
            p_value = float(test_result[0])
            ratio_stats = test_result[1]
            # NOTE(review): entries 0 and 2 are read as the mean ratios of the
            # first and second group - confirm against the isotools documentation.
            mean_ratios = (float(ratio_stats[0]), float(ratio_stats[2]))
        except Exception as e:
            print(f"Error testing syntelog {synt_id}: {str(e)}")
            continue

        tests.append({
            'Synt_id': synt_id,
            'positions': tested,
            'n_alleles': len(iso1_idx),
            'p_value': p_value,
            'mean_ratios': mean_ratios,
        })

    # Benjamini-Hochberg across syntelogs; undefined (NaN) p-values are left
    # uncorrected so that they cannot distort the adjustment of the others.
    p_values = np.array([t['p_value'] for t in tests], dtype=float)
    fdr_values = np.full(len(tests), np.nan)
    testable = np.isfinite(p_values)
    if testable.any():
        fdr_values[testable] = multipletests(p_values[testable], method='fdr_bh')[1]

    # Per-variable result arrays (NaN for variables that were not tested).
    pvals = np.full(adata.n_vars, np.nan)
    fdr_pvals = np.full(adata.n_vars, np.nan)
    ratio_diff = np.full(adata.n_vars, np.nan)
    mean_ratio = np.full(adata.n_vars, np.nan)

    mean_column = f'ratios_{test_condition}_mean'
    results = []
    for test, fdr in zip(tests, fdr_values):
        ratio_difference = abs(test['mean_ratios'][0] - test['mean_ratios'][1])
        for slot, pos in enumerate(test['positions']):
            pvals[pos] = test['p_value']
            fdr_pvals[pos] = fdr
            ratio_diff[pos] = ratio_difference
            mean_ratio[pos] = test['mean_ratios'][slot]
            results.append({
                'Synt_id': test['Synt_id'],
                'allele': allele_label(hap_values[pos], slot + 1),
                'transcript_id': transcript_ids[pos],
                'p_value': test['p_value'],
                'FDR': fdr,
                'ratio_difference': ratio_difference,
                'n_alleles': test['n_alleles'],
                mean_column: test['mean_ratios'][slot],
            })

    # Explicit columns keep the schema stable even when nothing was tested.
    results_df = pd.DataFrame(
        results,
        columns=['Synt_id', 'allele', 'transcript_id', 'p_value', 'FDR',
                 'ratio_difference', 'n_alleles', mean_column],
    )
    results_df = results_df.sort_values('p_value', kind='stable').reset_index(drop=True)

    # Store results in the AnnData object
    adata.uns['isoform_ratio_test'] = results_df
    adata.var['isoform_ratio_pval'] = pvals
    adata.var['isoform_ratio_FDR'] = fdr_pvals
    adata.var['isoform_ratio_difference'] = ratio_diff
    adata.var[f'isoform_ratio_mean_{test_condition}'] = mean_ratio

    return results_df if inplace else adata


