In [1]:
import anndict as adt
import scanpy as sc
adata = sc.read_h5ad("TSP33_pancreas_annotated_top_10000.h5ad")


In [None]:
adata

In [3]:
def detect_isoform_switching(
    adata,
    ct_col='gpt-3.5-turbo_simplified_ai_cell_type',
    transcript_id_col='transcript_id',
    layer=None,
    min_expression_frac=0.5,  # Minimum fraction to be considered "predominant"
    min_switching_frac=0.2,    # Minimum fraction of CTs that must switch
    min_cells_per_ct=10,       # Minimum cells per cell type to include
    prefix_patterns=None,      # e.g., ['ENST', 'PB'] or None to auto-detect
    verbose=True
):
    """
    Detect genes with isoform switching between different transcript types (e.g., ENST vs PB).
    
    Parameters:
    -----------
    adata : AnnData
        Annotated data object
    ct_col : str
        Column name in adata.obs containing cell type labels
    transcript_id_col : str
        Column name in adata.var containing transcript IDs
    layer : str or None
        Layer to use (None for adata.X)
    min_expression_frac : float
        Minimum fraction of expression from a prefix type to be considered "predominant"
    min_switching_frac : float
        Minimum fraction of cell types that must show switching pattern
    min_cells_per_ct : int
        Minimum number of cells per cell type to include in analysis
    prefix_patterns : list or None
        List of prefix patterns to group by (e.g., ['ENST', 'PB']).
        If None, auto-detects common prefixes.
    verbose : bool
        Print progress information
    
    Returns:
    --------
    switching_df : DataFrame
        DataFrame with columns:
        - gene: gene name
        - dominant_prefix: Most common prefix across all CTs
        - switching_prefix: Prefix that appears in switching CTs
        - n_cell_types: Total number of cell types expressing the gene
        - n_switching_cts: Number of cell types showing switching
        - switching_cts: List of cell types showing switching
        - dominant_cts: List of cell types with dominant prefix
        - switching_frac: Fraction of CTs showing switching
    """
    
    # Helper function to get matrix
    def get_mat(A):
        mat = A.layers[layer] if layer is not None else A.X
        return mat
    
    def sum_over_axis0(A):
        mat = get_mat(A)
        if hasattr(mat, 'sum'):
            result = mat.sum(axis=0)
            if hasattr(result, 'A1'):  # sparse matrix
                return result.A1
            return np.asarray(result).ravel()
        return np.asarray(mat.sum(axis=0)).ravel()
    
    # Sanity checks
    if transcript_id_col not in adata.var.columns:
        raise ValueError(f"adata.var must contain '{transcript_id_col}'")
    if 'gene_name' not in adata.var.columns:
        raise ValueError("adata.var must contain 'gene_name'")
    if ct_col not in adata.obs.columns:
        raise ValueError(f"adata.obs must contain '{ct_col}'")
    
    # Auto-detect prefix patterns if not provided
    if prefix_patterns is None:
        transcript_ids = adata.var[transcript_id_col].astype(str)
        # Extract prefixes (everything before first dot or underscore, or first 3-4 chars)
        prefixes = set()
        for tx_id in transcript_ids.unique()[:1000]:  # Sample first 1000
            if '.' in tx_id:
                prefix = tx_id.split('.')[0]
            elif '_' in tx_id:
                prefix = tx_id.split('_')[0]
            else:
                prefix = tx_id[:4]  # First 4 characters
            if len(prefix) >= 2:
                prefixes.add(prefix)
        prefix_patterns = sorted(list(prefixes))
        if verbose:
            print(f"Auto-detected prefix patterns: {prefix_patterns}")
    
    # Get all unique genes
    all_genes = adata.var['gene_name'].unique()
    if verbose:
        print(f"Analyzing {len(all_genes)} genes...")
    
    # Get cell types with sufficient cells
    ct_counts = adata.obs[ct_col].value_counts()
    valid_cts = ct_counts[ct_counts >= min_cells_per_ct].index.tolist()
    if verbose:
        print(f"Using {len(valid_cts)} cell types with >= {min_cells_per_ct} cells")
    
    switching_results = []
    
    for gene_idx, gene in enumerate(all_genes):
        if verbose and (gene_idx + 1) % 100 == 0:
            print(f"  Processed {gene_idx + 1}/{len(all_genes)} genes...")
        
        # Find transcripts for this gene
        gene_mask = adata.var['gene_name'] == gene
        if gene_mask.sum() == 0:
            continue
        
        # Subset to gene
        adata_gene = adata[:, gene_mask].copy()
        transcript_ids = adata_gene.var[transcript_id_col].astype(str).values
        
        # Group transcripts by prefix
        prefix_groups = {}
        for prefix in prefix_patterns:
            prefix_mask = np.array([tx_id.startswith(prefix) for tx_id in transcript_ids])
            if prefix_mask.sum() > 0:
                prefix_groups[prefix] = prefix_mask
        
        # Need at least 2 different prefix types for switching
        if len(prefix_groups) < 2:
            continue
        
        # Analyze each cell type
        ct_results = {}
        for ct in valid_cts:
            ct_mask = adata.obs[ct_col] == ct
            if ct_mask.sum() < min_cells_per_ct:
                continue
            
            ad_ct_gene = adata_gene[ct_mask]
            counts = sum_over_axis0(ad_ct_gene)
            total = counts.sum()
            
            if total == 0:
                continue
            
            # Calculate fraction for each prefix type
            prefix_fracs = {}
            for prefix, prefix_mask in prefix_groups.items():
                prefix_counts = counts[prefix_mask].sum()
                prefix_fracs[prefix] = prefix_counts / total if total > 0 else 0.0
            
            # Find dominant prefix (highest fraction)
            if prefix_fracs:
                dominant_prefix = max(prefix_fracs.items(), key=lambda x: x[1])[0]
                dominant_frac = prefix_fracs[dominant_prefix]
                ct_results[ct] = {
                    'dominant_prefix': dominant_prefix,
                    'dominant_frac': dominant_frac,
                    'prefix_fracs': prefix_fracs.copy()
                }
        
        if len(ct_results) < 2:  # Need at least 2 cell types
            continue
        
        # Identify switching pattern
        # Count how many CTs use each prefix as dominant
        prefix_counts = {}
        for ct, result in ct_results.items():
            dom_prefix = result['dominant_prefix']
            prefix_counts[dom_prefix] = prefix_counts.get(dom_prefix, 0) + 1
        
        # Find the most common dominant prefix (majority pattern)
        if not prefix_counts:
            continue
        
        majority_prefix = max(prefix_counts.items(), key=lambda x: x[1])[0]
        majority_count = prefix_counts[majority_prefix]
        total_cts = len(ct_results)
        
        # Find CTs that use a different prefix
        switching_cts = []
        switching_prefixes = set()
        for ct, result in ct_results.items():
            if result['dominant_prefix'] != majority_prefix:
                # Check if the switching is significant (above threshold)
                if result['dominant_frac'] >= min_expression_frac:
                    switching_cts.append(ct)
                    switching_prefixes.add(result['dominant_prefix'])
        
        # Check if switching is significant
        switching_frac = len(switching_cts) / total_cts if total_cts > 0 else 0
        
        if (len(switching_cts) > 0 and 
            switching_frac >= min_switching_frac and
            len(switching_prefixes) > 0):
            
            switching_prefix = list(switching_prefixes)[0] if len(switching_prefixes) == 1 else ','.join(sorted(switching_prefixes))
            dominant_cts = [ct for ct, r in ct_results.items() 
                          if r['dominant_prefix'] == majority_prefix]
            
            switching_results.append({
                'gene': gene,
                'dominant_prefix': majority_prefix,
                'switching_prefix': switching_prefix,
                'n_cell_types': total_cts,
                'n_switching_cts': len(switching_cts),
                'switching_cts': ','.join(sorted(switching_cts)),
                'dominant_cts': ','.join(sorted(dominant_cts)),
                'switching_frac': switching_frac,
                'majority_count': majority_count
            })
    
    if len(switching_results) == 0:
        if verbose:
            print("No genes with isoform switching detected.")
        return pd.DataFrame()
    
    switching_df = pd.DataFrame(switching_results)
    switching_df = switching_df.sort_values('switching_frac', ascending=False)
    
    if verbose:
        print(f"\nFound {len(switching_df)} genes with isoform switching")
    
    return switching_df


In [8]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def plot_isoform_switching(
    adata,
    gene,
    switching_df,
    ct_col='gpt-3.5-turbo_simplified_ai_cell_type',
    transcript_id_col='transcript_id',
    layer=None,
    figsize=(14, 8),
    prefix_patterns=None
):
    """
    Plot isoform switching pattern for a specific gene.
    Shows the fraction of expression from each prefix type (e.g., ENST vs PB) per cell type.
    """
    
    # Get switching info for this gene
    gene_info = switching_df[switching_df['gene'] == gene]
    if len(gene_info) == 0:
        raise ValueError(f"Gene {gene} not found in switching_df")
    
    gene_info = gene_info.iloc[0]
    dominant_prefix = gene_info['dominant_prefix']
    switching_prefix = gene_info['switching_prefix'].split(',')
    
    # Auto-detect prefixes if needed
    if prefix_patterns is None:
        all_prefixes = [dominant_prefix] + switching_prefix
        prefix_patterns = sorted(set(all_prefixes))
    
    # Helper functions
    def get_mat(A):
        mat = A.layers[layer] if layer is not None else A.X
        return mat
    
    def sum_over_axis0(A):
        mat = get_mat(A)
        if hasattr(mat, 'sum'):
            result = mat.sum(axis=0)
            if hasattr(result, 'A1'):
                return result.A1
            return np.asarray(result).ravel()
        return np.asarray(mat.sum(axis=0)).ravel()
    
    # Subset to gene
    gene_mask = adata.var['gene_name'] == gene
    adata_gene = adata[:, gene_mask].copy()
    transcript_ids = adata_gene.var[transcript_id_col].astype(str).values
    
    # Get cell types
    cell_types = sorted(adata.obs[ct_col].unique())
    
    # Calculate prefix fractions per cell type
    ct_data = []
    for ct in cell_types:
        ct_mask = adata.obs[ct_col] == ct
        if ct_mask.sum() == 0:
            continue
        
        ad_ct_gene = adata_gene[ct_mask]
        counts = sum_over_axis0(ad_ct_gene)
        total = counts.sum()
        
        if total == 0:
            continue
        
        # Group by prefix
        prefix_fracs = {}
        for prefix in prefix_patterns:
            prefix_mask = np.array([tx_id.startswith(prefix) for tx_id in transcript_ids])
            if prefix_mask.sum() > 0:
                prefix_counts = counts[prefix_mask].sum()
                prefix_fracs[prefix] = prefix_counts / total
            else:
                prefix_fracs[prefix] = 0.0
        
        ct_data.append({
            'cell_type': ct,
            **prefix_fracs
        })
    
    if len(ct_data) == 0:
        raise ValueError(f"No expression found for gene {gene}")
    
    df = pd.DataFrame(ct_data)
    
    # Sort by dominant prefix fraction (descending)
    df = df.sort_values(dominant_prefix, ascending=False)
    
    # Create stacked bar plot
    fig, ax = plt.subplots(figsize=figsize)
    
    x = np.arange(len(df))
    bottom = np.zeros(len(df))
    
    colors = plt.cm.tab10(np.linspace(0, 1, len(prefix_patterns)))
    prefix_colors = {prefix: colors[i] for i, prefix in enumerate(prefix_patterns)}
    
    for prefix in prefix_patterns:
        values = df[prefix].values
        ax.bar(x, values, bottom=bottom, label=prefix, color=prefix_colors[prefix], alpha=0.8)
        bottom += values
    
    # Highlight switching cell types
    switching_cts = gene_info['switching_cts'].split(',')
    for i, ct in enumerate(df['cell_type']):
        if ct in switching_cts:
            ax.axvline(i, color='red', linestyle='--', alpha=0.5, linewidth=2)
    
    ax.set_xlabel('Cell Type', fontsize=12)
    ax.set_ylabel('Fraction of Expression', fontsize=12)
    ax.set_title(f'Isoform Switching: {gene}\n'
                 f'Dominant: {dominant_prefix}, Switching: {", ".join(switching_prefix)}', 
                 fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(df['cell_type'], rotation=45, ha='right', fontsize=10)
    ax.legend(title='Transcript Type', fontsize=10)
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    return fig


In [None]:
import numpy as np
import pandas as pd
# Example usage: Detect isoform switching
# This will analyze all genes and identify those with switching patterns

switching_results = detect_isoform_switching(
    adata,
    ct_col='gpt-3.5-turbo_simplified_ai_cell_type',
    transcript_id_col='transcript_id',
    layer=None,
    min_expression_frac=0.5,  # At least 50% of expression from a prefix type
    min_switching_frac=0.2,    # At least 20% of cell types must show switching
    min_cells_per_ct=10,       # Minimum cells per cell type
    prefix_patterns=['ENST', 'PB'],  # Or None to auto-detect
    verbose=True
)

# Display results
print(f"\nFound {len(switching_results)} genes with isoform switching")
if len(switching_results) > 0:
    print("\nTop switching genes:")
    print(switching_results.head(10))


Analyzing 6601 genes...
Using 1 cell types with >= 10 cells
  Processed 100/6601 genes...
  Processed 200/6601 genes...
  Processed 300/6601 genes...
  Processed 400/6601 genes...
  Processed 500/6601 genes...
  Processed 600/6601 genes...
  Processed 700/6601 genes...
  Processed 800/6601 genes...
  Processed 900/6601 genes...
  Processed 1000/6601 genes...
  Processed 1100/6601 genes...
  Processed 1200/6601 genes...
  Processed 1300/6601 genes...
  Processed 1400/6601 genes...
  Processed 1500/6601 genes...
  Processed 1600/6601 genes...
  Processed 1700/6601 genes...
  Processed 1800/6601 genes...
  Processed 1900/6601 genes...
  Processed 2000/6601 genes...
  Processed 2100/6601 genes...
  Processed 2200/6601 genes...
  Processed 2300/6601 genes...
  Processed 2400/6601 genes...
  Processed 2500/6601 genes...
  Processed 2600/6601 genes...
  Processed 2700/6601 genes...
  Processed 2800/6601 genes...
  Processed 2900/6601 genes...
  Processed 3000/6601 genes...
  Processed 3100/66

NameError: name 'pd' is not defined

In [9]:
import numpy as np
import pandas as pd

def detect_isoform_switching_2(
    adata,
    ct_col='gpt-3.5-turbo_simplified_ai_cell_type',
    transcript_id_col='transcript_id',
    layer=None,
    min_expression_frac=0.5,    # min fraction to call a prefix/isoform "dominant"
    min_switching_frac=0.2,     # min fraction of CTs that must switch
    min_cells_per_ct=10,        # min cells per CT to include
    prefix_patterns=None,       # e.g. ['ENST', 'PB']; if None, auto-detect
    min_baseline_prop=0.5,      # min fraction of CTs that must follow majority prefix strongly
    verbose=True,
):
    """
    Detect genes with transcript-family (prefix) switching across cell types,
    and summarize whether a specific non-majority isoform repeatedly drives switching.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with cells in .obs and transcripts in .var
    ct_col : str
        Column in adata.obs containing cell type labels
    transcript_id_col : str
        Column in adata.var containing transcript IDs (e.g. ENST..., PB...)
    layer : str or None
        Layer to use for counts (None uses adata.X)
    min_expression_frac : float
        Minimum fraction of gene expression from a prefix (or isoform) to be
        considered "dominant" in a given cell type
    min_switching_frac : float
        Minimum fraction of cell types that must show switching for a gene
    min_cells_per_ct : int
        Minimum number of cells required in a cell type to include it
    prefix_patterns : list of str or None
        List of prefix patterns to group isoforms by (e.g. ['ENST','PB']).
        If None, auto-detects common prefixes from transcript IDs.
    min_baseline_prop : float
        Minimum fraction of cell types that must follow the majority prefix
        with strong dominance to call a gene "switching"
    verbose : bool
        If True, print progress info

    Returns
    -------
    switching_df : pandas.DataFrame
        One row per gene classified as switching, with columns:
          - gene
          - dominant_prefix         (majority/baseline prefix across CTs)
          - switching_prefix        (prefix(es) used in switching CTs)
          - n_cell_types            (CTs with enough data)
          - n_switching_cts
          - switching_cts           (comma-separated CT names)
          - dominant_cts            (CTs where majority prefix is dominant)
          - baseline_cts            (CTs where majority prefix is dominant AND strong)
          - switching_frac          (n_switching_cts / n_cell_types)
          - baseline_prop           (len(baseline_cts) / n_cell_types)
          - majority_count          (how many CTs majority prefix is dominant)
          - main_switch_isoform     (isoform ID that most often dominates in switching CTs)
          - n_ct_main_switch_isoform(# switching CTs where that isoform is top)
          - isoform_switch_frac     (n_ct_main_switch_isoform / n_switching_cts)
    """

    # -------- helpers for matrix access --------
    def get_mat(A):
        return A.layers[layer] if layer is not None else A.X

    def sum_over_axis0(A):
        mat = get_mat(A)
        result = mat.sum(axis=0)
        # dense or sparse
        if hasattr(result, "A1"):
            return result.A1
        return np.asarray(result).ravel()

    # -------- sanity checks --------
    if transcript_id_col not in adata.var.columns:
        raise ValueError(f"adata.var must contain '{transcript_id_col}'")
    if "gene_name" not in adata.var.columns:
        raise ValueError("adata.var must contain 'gene_name'")
    if ct_col not in adata.obs.columns:
        raise ValueError(f"adata.obs must contain '{ct_col}'")

    # -------- auto-detect prefix patterns if needed --------
    if prefix_patterns is None:
        tx_ids_all = adata.var[transcript_id_col].astype(str)
        prefixes = set()
        for tx_id in tx_ids_all.unique()[:1000]:  # sample first 1000
            if "." in tx_id:
                prefix = tx_id.split(".")[0]
            elif "_" in tx_id:
                prefix = tx_id.split("_")[0]
            else:
                prefix = tx_id[:4]  # first 4 characters
            if len(prefix) >= 2:
                prefixes.add(prefix)
        prefix_patterns = sorted(prefixes)
        if verbose:
            print(f"Auto-detected prefix patterns: {prefix_patterns}")

    # -------- gene + CT sets --------
    all_genes = adata.var["gene_name"].unique()
    if verbose:
        print(f"Analyzing {len(all_genes)} genes...")

    ct_counts = adata.obs[ct_col].value_counts()
    valid_cts = ct_counts[ct_counts >= min_cells_per_ct].index.tolist()
    if verbose:
        print(f"Using {len(valid_cts)} cell types with ≥ {min_cells_per_ct} cells")

    switching_results = []

    # -------- main loop over genes --------
    for gene_idx, gene in enumerate(all_genes):
        if verbose and (gene_idx + 1) % 100 == 0:
            print(f"  Processed {gene_idx + 1}/{len(all_genes)} genes...")

        # transcripts for this gene
        gene_mask = adata.var["gene_name"] == gene
        if gene_mask.sum() == 0:
            continue

        # subset AnnData to this gene (columns = isoforms for gene)
        adata_gene = adata[:, gene_mask]
        transcript_ids = adata_gene.var[transcript_id_col].astype(str).values

        # map each isoform to a prefix (first matching pattern)
        tx_prefix = np.array([
            next((p for p in prefix_patterns if tx_id.startswith(p)), None)
            for tx_id in transcript_ids
        ])

        # group isoforms by prefix (per gene)
        prefix_groups = {}
        for prefix in prefix_patterns:
            prefix_mask = (tx_prefix == prefix)
            if prefix_mask.sum() > 0:
                prefix_groups[prefix] = prefix_mask

        # need at least 2 prefix families to talk about switching
        if len(prefix_groups) < 2:
            continue

        # ---- per-CT results for this gene ----
        ct_results = {}  # ct -> dict with dominant_prefix, dominant_frac, iso_fracs, etc.

        for ct in valid_cts:
            ct_mask = (adata.obs[ct_col] == ct)
            if ct_mask.sum() < min_cells_per_ct:
                continue

            ad_ct_gene = adata_gene[ct_mask]
            counts = sum_over_axis0(ad_ct_gene)  # length = #isoforms in gene
            total = counts.sum()

            if total <= 0:
                continue

            iso_fracs = counts / total  # fraction of gene counts per isoform

            # prefix-level fractions
            prefix_fracs = {}
            for prefix, prefix_mask in prefix_groups.items():
                prefix_counts = counts[prefix_mask].sum()
                prefix_fracs[prefix] = (prefix_counts / total) if total > 0 else 0.0

            if not prefix_fracs:
                continue

            # dominant prefix for this gene × CT
            dominant_prefix = max(prefix_fracs.items(), key=lambda x: x[1])[0]
            dominant_frac = prefix_fracs[dominant_prefix]

            ct_results[ct] = {
                "dominant_prefix": dominant_prefix,
                "dominant_frac": dominant_frac,
                "prefix_fracs": prefix_fracs.copy(),
                "iso_fracs": iso_fracs,
                "iso_counts": counts,
            }

        # need ≥ 2 CTs with data for this gene
        if len(ct_results) < 2:
            continue

        # ---- majority (baseline) prefix across CTs ----
        prefix_counts = {}
        for ct, r in ct_results.items():
            p = r["dominant_prefix"]
            prefix_counts[p] = prefix_counts.get(p, 0) + 1

        if not prefix_counts:
            continue

        majority_prefix = max(prefix_counts.items(), key=lambda x: x[1])[0]
        majority_count = prefix_counts[majority_prefix]
        total_cts = len(ct_results)

        # ---- baseline CTs: majority prefix, strongly dominant ----
        baseline_cts = [
            ct for ct, r in ct_results.items()
            if (
                r["dominant_prefix"] == majority_prefix
                and r["dominant_frac"] >= min_expression_frac
            )
        ]
        baseline_prop = len(baseline_cts) / total_cts if total_cts > 0 else 0.0

        # ---- switching CTs: non-majority prefix, strongly dominant ----
        switching_cts = []
        switching_prefixes = set()
        for ct, r in ct_results.items():
            if r["dominant_prefix"] != majority_prefix and r["dominant_frac"] >= min_expression_frac:
                switching_cts.append(ct)
                switching_prefixes.add(r["dominant_prefix"])

        switching_frac = len(switching_cts) / total_cts if total_cts > 0 else 0.0

        # ---- NEW: main switching isoform across switching CTs ----
        main_switch_isoform = None
        n_ct_main_switch_isoform = 0
        isoform_switch_frac = 0.0

        if len(switching_cts) > 0:
            non_major_mask = (tx_prefix != majority_prefix)
            switch_isoforms_per_ct = {}  # ct -> isoform_id
            isoform_ct_counts = {}       # isoform_id -> #switching CTs top

            if non_major_mask.any():
                for ct in switching_cts:
                    r = ct_results[ct]
                    iso_fracs_ct = r["iso_fracs"]

                    fracs_non_major = iso_fracs_ct[non_major_mask]
                    if fracs_non_major.sum() == 0:
                        continue

                    rel_idx = np.argmax(fracs_non_major)
                    abs_idx = np.where(non_major_mask)[0][rel_idx]
                    best_isoform_id = transcript_ids[abs_idx]

                    switch_isoforms_per_ct[ct] = best_isoform_id
                    isoform_ct_counts[best_isoform_id] = isoform_ct_counts.get(best_isoform_id, 0) + 1

            if len(isoform_ct_counts) > 0:
                main_switch_isoform, n_ct_main_switch_isoform = max(
                    isoform_ct_counts.items(), key=lambda x: x[1]
                )
                isoform_switch_frac = n_ct_main_switch_isoform / len(switching_cts)

        # ---- gene-level criteria: must have both groups + sufficient fractions ----
        if (
            len(switching_cts) > 0                        # at least one switching CT
            and len(baseline_cts) > 0                     # at least one baseline CT
            and switching_frac >= min_switching_frac      # enough CTs switch
            and baseline_prop >= min_baseline_prop        # enough CTs baseline
            and len(switching_prefixes) > 0               # switching prefix defined
        ):
            switching_prefix = (
                list(switching_prefixes)[0]
                if len(switching_prefixes) == 1
                else ",".join(sorted(switching_prefixes))
            )
            dominant_cts = [
                ct for ct, r in ct_results.items()
                if r["dominant_prefix"] == majority_prefix
            ]

            switching_results.append({
                "gene": gene,
                "dominant_prefix": majority_prefix,
                "switching_prefix": switching_prefix,
                "n_cell_types": total_cts,
                "n_switching_cts": len(switching_cts),
                "switching_cts": ",".join(sorted(switching_cts)),
                "dominant_cts": ",".join(sorted(dominant_cts)),
                "baseline_cts": ",".join(sorted(baseline_cts)),
                "switching_frac": switching_frac,
                "baseline_prop": baseline_prop,
                "majority_count": majority_count,
                "main_switch_isoform": main_switch_isoform,
                "n_ct_main_switch_isoform": n_ct_main_switch_isoform,
                "isoform_switch_frac": isoform_switch_frac,
            })

    if len(switching_results) == 0:
        if verbose:
            print("No genes with isoform switching detected.")
        return pd.DataFrame()

    switching_df = pd.DataFrame(switching_results)
    switching_df = switching_df.sort_values(
        ["switching_frac", "isoform_switch_frac"], ascending=False
    ).reset_index(drop=True)

    if verbose:
        print(f"\nFound {len(switching_df)} genes with isoform switching")

    return switching_df


In [10]:
switching_df = detect_isoform_switching_2(
    adata,
    ct_col='gpt-3.5-turbo_simplified_ai_cell_type',
    transcript_id_col='transcript_id',   # or whatever your column is
    layer=None,                          # or 'counts' / 'X_raw' etc.
    min_expression_frac=0.2,
    min_switching_frac=0.1,
    min_cells_per_ct=2,
    prefix_patterns=['ENST', 'PB'],      # or let it auto-detect
    min_baseline_prop=0.1,
    verbose=True,
)

switching_df.head()


Analyzing 6601 genes...
Using 1 cell types with ≥ 2 cells
  Processed 100/6601 genes...
  Processed 200/6601 genes...
  Processed 300/6601 genes...
  Processed 400/6601 genes...
  Processed 500/6601 genes...
  Processed 600/6601 genes...
  Processed 700/6601 genes...
  Processed 800/6601 genes...
  Processed 900/6601 genes...
  Processed 1000/6601 genes...
  Processed 1100/6601 genes...
  Processed 1200/6601 genes...
  Processed 1300/6601 genes...
  Processed 1400/6601 genes...
  Processed 1500/6601 genes...
  Processed 1600/6601 genes...
  Processed 1700/6601 genes...
  Processed 1800/6601 genes...
  Processed 1900/6601 genes...
  Processed 2000/6601 genes...
  Processed 2100/6601 genes...
  Processed 2200/6601 genes...
  Processed 2300/6601 genes...
  Processed 2400/6601 genes...
  Processed 2500/6601 genes...
  Processed 2600/6601 genes...
  Processed 2700/6601 genes...
  Processed 2800/6601 genes...
  Processed 2900/6601 genes...
  Processed 3000/6601 genes...
  Processed 3100/6601