In [3]:
import anndata as ad

# Read the full dataset from its .h5ad file. Later cells subsample it because the
# object is very large (the recorded output reports 86,002 cells x 5,498,247 transcripts).
adata = ad.read_h5ad("all_tissues_raw_counts_transcripts_bc_anndata.h5ad")

In [4]:
# Report whether adata.var has a 'transcript_id' column and, when it does,
# summarize its contents (head, total, number of unique IDs, random sample).
print("Checking transcript_id column in adata.var:")
print(f"\nDoes 'transcript_id' exist in adata.var.columns? {'transcript_id' in adata.var.columns}")
print(f"\nVar columns: {list(adata.var.columns)}")

if 'transcript_id' not in adata.var.columns:
    # Nothing to summarize: list what is available instead
    print("\n'transcript_id' column not found in adata.var!")
    print("Available columns:", list(adata.var.columns))
else:
    print("\nFirst 10 transcript IDs:")
    print(adata.var['transcript_id'].head(10))
    print(f"\nTotal number of transcripts: {len(adata.var['transcript_id'])}")
    print(f"\nUnique transcript IDs: {adata.var['transcript_id'].nunique()}")
    print("\nSample of transcript IDs:")
    # Draw at most 20 IDs (fewer if var has fewer rows); the draw is unseeded
    print(adata.var['transcript_id'].sample(min(20, len(adata.var))).tolist())


Checking transcript_id column in adata.var:

Does 'transcript_id' exist in adata.var.columns? True

Var columns: ['transcript_id', 'gene_name']

First 10 transcript IDs:
ENST00000000233.10    ENST00000000233.10
ENST00000000412.8      ENST00000000412.8
ENST00000000442.11    ENST00000000442.11
ENST00000001008.6      ENST00000001008.6
ENST00000001146.7      ENST00000001146.7
ENST00000002125.9      ENST00000002125.9
ENST00000002165.11    ENST00000002165.11
ENST00000002501.11    ENST00000002501.11
ENST00000002596.6      ENST00000002596.6
ENST00000002829.8      ENST00000002829.8
Name: transcript_id, dtype: category
Categories (5436103, object): ['ENST00000000233.10', 'ENST00000000412.8', 'ENST00000000442.11', 'ENST00000001008.6', ..., 'PB.853226.65', 'PB.853226.74', 'PB.853226.76', 'PB.853247.4']

Total number of transcripts: 5498247

Unique transcript IDs: 5436103

Sample of transcript IDs:
['PB.51985.142', 'PB.5538.9', 'PB.156910.261', 'PB.214855.16', 'PB.240635.550', 'PB.183806.24', 'PB.1

In [3]:
# Create a subset of the h5ad file for testing
# Choose one of the options below:

# Option 1: Random sample of cells (recommended for testing)
import numpy as np
np.random.seed(42)  # For reproducibility

n_cells_subset = 1000  # Number of cells to keep; adjust as needed
if adata.n_obs <= n_cells_subset:
    # Nothing to sample from: keep every cell
    adata_subset = adata.copy()
    print(f"Data already smaller than requested, using full dataset: {adata_subset.shape}")
else:
    # Draw n_cells_subset distinct row positions (no replacement) and materialize the subset
    cell_indices = np.random.choice(adata.n_obs, size=n_cells_subset, replace=False)
    adata_subset = adata[cell_indices].copy()
    print(f"Created random subset: {adata_subset.shape} (from {adata.shape})")

# Option 2: First N cells (uncomment to use instead)
# n_cells_subset = 10000
# adata_subset = adata[:n_cells_subset].copy()
# print(f"Created subset (first {n_cells_subset} cells): {adata_subset.shape}")

# Option 3: Specific samples/tissues (uncomment to use instead)
# # Filter to specific samples
# specific_samples = ['TSP33_salivary-gland_5_kinnexsc', 'TSP33_liver_20_kinnexsc', 'TSP33_lung_10_kinnexsc']
# mask = adata.obs['Sample'].isin(specific_samples)
# adata_subset = adata[mask].copy()
# print(f"Created subset from specific samples: {adata_subset.shape}")

# Option 4: Random sample of cells AND subset of genes (to reduce size further)
# n_cells_subset = 5000
# n_genes_subset = 50000  # Keep top N most expressed genes
# cell_indices = np.random.choice(adata.n_obs, size=min(n_cells_subset, adata.n_obs), replace=False)
# # Get top expressed genes
# gene_counts = np.asarray(adata.X.sum(axis=0)).ravel()
# top_gene_indices = np.argsort(gene_counts)[-n_genes_subset:]
# adata_subset = adata[np.sort(cell_indices), top_gene_indices].copy()
# print(f"Created subset: {adata_subset.shape} ({n_cells_subset} cells, {n_genes_subset} top genes)")

# Save the subset to a new file and report how large the written file is
import os

output_file = "all_tissues_raw_counts_transcripts_bc_anndata_subset.h5ad"
adata_subset.write_h5ad(output_file)
print(f"\nSaved subset to: {output_file}")

# os.path.getsize returns bytes; report mebibytes for readability
# (previously this line printed the file name again instead of its size)
file_size_mib = os.path.getsize(output_file) / (1024 ** 2)
print(f"File size: {file_size_mib:,.1f} MiB")


Created random subset: (1000, 5498247) (from (86002, 5498247))

Saved subset to: all_tissues_raw_counts_transcripts_bc_anndata_subset.h5ad
File size: all_tissues_raw_counts_transcripts_bc_anndata_subset.h5ad


In [4]:
# Reload the subset file written by the subsetting cell. This rebinds `adata`,
# so from here on `adata` refers to the subset, not the full dataset.
adata = ad.read_h5ad("all_tissues_raw_counts_transcripts_bc_anndata_subset.h5ad")

In [5]:
adata

AnnData object with n_obs × n_vars = 1000 × 5498247
    obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'Sample', 'n_counts'
    var: 'transcript_id', 'gene_name'

In [6]:
# Just the first few entries of the Sample column
adata.obs["Sample"].head()


CAGAATCACGGATCCC-1_TSP28_ovary-left_29_kinnexsc    TSP28_ovary-left_29_kinnexsc
GACAGGTGATGTCCGG-1_TSP33_lung_10_kinnexsc                TSP33_lung_10_kinnexsc
CCGGCATGAGGCAATG-1_TSP28_ovary-left_29_kinnexsc    TSP28_ovary-left_29_kinnexsc
ATTGCTACTGACCCGT-1_TSP33_ear_3_kinnexsc                    TSP33_ear_3_kinnexsc
GGAAGTGACGAGGTTG-1_TSP33_bone-marrow_4_kinnexsc    TSP33_bone-marrow_4_kinnexsc
Name: Sample, dtype: category
Categories (23, object): ['TSP21_uterus-endometrium_31_kinnexsc', 'TSP28_ovary-left-1_30_kinnexsc', 'TSP28_ovary-left_29_kinnexsc', 'TSP31_lymph-node_27_kinnexsc', ..., 'TSP33_thymus_17_kinnexsc', 'TSP33_tongue_13_kinnexsc', 'TSP33_trachea_25_kinnexsc', 'TSP33_vasculature-aor_22_kinnexsc']

In [7]:
import pandas as pd

# Option 1: Work with sparse matrix directly (memory efficient)
# Access the sparse matrix without converting to dense
print(f"Data shape: {adata.shape}")
print(f"Data type: {type(adata.X)}")
print("\nFirst few rows (sparse format):")

# Slice a 5 x 10 corner first so only that small block is ever densified
corner = adata.X[:5, :10]
if hasattr(adata.X, 'toarray'):
    corner = corner.toarray()
print(corner)

# Option 2: Convert only a small subset to DataFrame (if needed)
# df_sample = adata[:1000, :1000].to_df()  # First 1000 cells and 1000 transcripts
# df_sample.head()

# Option 3: Access specific rows/columns without full conversion
# df_subset = pd.DataFrame(
#     adata[:100, :100].X.toarray() if hasattr(adata.X, 'toarray') else adata[:100, :100].X,
#     index=adata.obs_names[:100],
#     columns=adata.var_names[:100]
# )
# df_subset.head()


Data shape: (1000, 5498247)
Data type: <class 'scipy.sparse._csr.csr_matrix'>

First few rows (sparse format):
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [2. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [8]:
# Filter adata to only TSP33 samples and add Tissue column
print(f"Original adata shape: {adata.shape}")

# Keep only cells whose Sample label begins with "TSP33".
# The column is cast to str first so the .str accessor can be used on it.
sample_labels = adata.obs["Sample"].astype(str)
tsp33_mask = sample_labels.str.startswith("TSP33")
adata = adata[tsp33_mask].copy()  # rebinds adata to the filtered copy

print(f"Filtered adata shape (TSP33 only): {adata.shape}")

# Extract tissue name from Sample column
# Pattern: TSP33_tissue-name_##_kinnexsc
def extract_tissue(sample_name):
    """Return the second underscore-separated field of `sample_name`.

    The input is converted with str() first. When the text contains no
    underscore (fewer than two fields), the whole text is returned unchanged.
    """
    text = str(sample_name)
    fields = text.split('_')
    # e.g. "TSP33_lung_10_kinnexsc" -> fields[1] == "lung"
    return fields[1] if len(fields) >= 2 else text

# Derive a per-cell 'Tissue' label from the 'Sample' label and store it in adata.obs
adata.obs['Tissue'] = adata.obs['Sample'].apply(extract_tissue)

# Quick check: which tissues are present, and how Sample maps to Tissue
print(f"\nUnique tissues: {adata.obs['Tissue'].unique()}")
print("\nSample of data:")
print(adata.obs[['Sample', 'Tissue']].head(10))


Original adata shape: (1000, 5498247)
Filtered adata shape (TSP33 only): (717, 5498247)

Unique tissues: ['lung' 'ear' 'bone-marrow' 'thymus' 'bladder' 'pancreas' 'testis'
 'tongue' 'small-intestine' 'fat' 'trachea' 'blood' 'salivary-gland'
 'muscle' 'ascending-colon' 'liver' 'vasculature-aor']

Sample of data:
                                                                       Sample  \
GACAGGTGATGTCCGG-1_TSP33_lung_10_kinnexsc              TSP33_lung_10_kinnexsc   
ATTGCTACTGACCCGT-1_TSP33_ear_3_kinnexsc                  TSP33_ear_3_kinnexsc   
GGAAGTGACGAGGTTG-1_TSP33_bone-marrow_4_kinnexsc  TSP33_bone-marrow_4_kinnexsc   
TGGAACACTCCATTCC-1_TSP33_thymus_17_kinnexsc          TSP33_thymus_17_kinnexsc   
TATGCAGACAAGGTAT-1_TSP33_bladder_15_kinnexsc        TSP33_bladder_15_kinnexsc   
CGAAACGGAGAAGGTC-1_TSP33_pancreas_12_kinnexsc      TSP33_pancreas_12_kinnexsc   
AGCTAGACTGCGTGAA-1_TSP33_lung_10_kinnexsc              TSP33_lung_10_kinnexsc   
CTAATTCACCGGATAG-1_TSP33_testis_16_kinn

In [9]:
import anndict as adt
# Split adata into a dictionary of AnnData objects grouped by obs['Tissue'].
# In the recorded output the keys are 1-tuples such as ('liver',).
adata_dict = adt.build_adata_dict(adata, ['Tissue'])

In [None]:
# Configure LLM backend


# Standard preprocessing steps (starting from raw counts), run independently on
# every AnnData in adata_dict.
# NOTE(review): in the recorded run, tissues with fewer than 50 cells failed PCA
# ("n_components=50 must be between 1 and min(n_samples, n_features)") and then
# failed the neighbor-graph-dependent steps. Consider a larger subset or removing
# very small tissues before running this cell.
adt.wrappers.normalize_adata_dict(adata_dict)  # Normalize
adt.wrappers.log_transform_adata_dict(adata_dict)  # Log
adt.wrappers.set_high_variance_genes_adata_dict(adata_dict, n_top_genes=2000, subset=False)  # Set highly variable genes
adt.wrappers.scale_adata_dict(adata_dict)  # Scale
adt.wrappers.pca_adata_dict(adata_dict, n_comps=50, mask_var='highly_variable')  # PCA
adt.wrappers.neighbors_adata_dict(adata_dict)  # Neighborhood graph
adt.wrappers.calculate_umap_adata_dict(adata_dict)  # UMAP
adt.wrappers.leiden_adata_dict(adata_dict, resolution=0.4)  # Cluster

# Run differential expression analysis independently on each anndata in adata_dict
adt.wrappers.rank_genes_groups_adata_dict(adata_dict, groupby='leiden', use_raw=False)

# Get the model name directly from the LLM config (just for naming the column in .obs)
# NOTE(review): this presumably requires the LLM backend to be configured earlier — confirm.
model = adt.get_llm_config()['model']

# Use the LLM to annotate cell types based on the 'leiden' column, passing tissue
# information from the 'Tissue' column. The obs column is named 'Tissue' (capital T,
# created from 'Sample' earlier in this notebook); the previous value 'tissue' did not
# match any obs column.
# The new labels are stored in adata.obs[new_label_column] for each adata in adata_dict.
new_label_column = f'{model}_ai_cell_type'
label_results = adt.wrappers.ai_annotate_cell_type_adata_dict(
    adata_dict,
    groupby='leiden',
    n_top_genes=10,
    new_label_column=new_label_column,
    tissue_of_origin_col='Tissue',
)

# These labels may have some redundancy, so merge them with the LLM
ai_label_column = f'{model}_simplified_ai_cell_type'
simplified_mappings = adt.wrappers.simplify_obs_column_adata_dict(
    adata_dict,
    new_label_column,
    ai_label_column,
    simplification_level='redundancy-removed',
)



Failed to process ('liver',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=6 with svd_solver='arpack'
Failed to process ('bladder',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=33 with svd_solver='arpack'


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_sub[k] = df_sub[k].cat.remove_unused_categories()


Failed to process ('salivary-gland',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=29 with svd_solver='arpack'Failed to process ('small-intestine',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=25 with svd_solver='arpack'

Failed to process ('bone-marrow',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=44 with svd_solver='arpack'
Failed to process ('testis',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=29 with svd_solver='arpack'
Failed to process ('tongue',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=7 with svd_solver='arpack'
Failed to process ('vasculature-aor',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=3 with svd_solver='arpack'Failed to process ('ascending-colon',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=6 wi

         Falling back to preprocessing with `sc.pp.pca` and default params.


Failed to process ('ascending-colon',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=6 with svd_solver='arpack'
Failed to process ('bladder',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=33 with svd_solver='arpack'


         Falling back to preprocessing with `sc.pp.pca` and default params.
  from .autonotebook import tqdm as notebook_tqdm


Failed to process ('bone-marrow',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=44 with svd_solver='arpack'


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


Failed to process ('liver',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=6 with svd_solver='arpack'
Failed to process ('salivary-gland',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=29 with svd_solver='arpack'
Failed to process ('small-intestine',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=25 with svd_solver='arpack'
Failed to process ('testis',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=29 with svd_solver='arpack'
Failed to process ('tongue',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=7 with svd_solver='arpack'


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


Failed to process ('vasculature-aor',) after 0 attempts: n_components=50 must be between 1 and min(n_samples, n_features)=3 with svd_solver='arpack'
Failed to process ('ascending-colon',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('bladder',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('bone-marrow',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('liver',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('salivary-gland',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('small-intestine',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('testis',) after 0 attempts: Did not find .uns['neighbors']. Run `sc.pp.neighbors` first.
Failed to process ('tongue',) after 0 attempts: Did not f

In [15]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def _distinct_colors(cmap_name: str, n: int):
    """
    Pick n colors from the named matplotlib colormap.

    If the colormap exposes a non-empty list/tuple `colors` attribute and n fits
    within it, n entries at roughly evenly spaced positions of that palette are
    returned. In every other case (no such attribute, n larger than the palette,
    or rounding produced duplicate positions) the colormap is sampled at n evenly
    spaced points in [0, 1).
    """
    cmap = plt.get_cmap(cmap_name)
    palette = getattr(cmap, "colors", None)
    has_discrete_palette = isinstance(palette, (list, tuple)) and len(palette) > 0

    if has_discrete_palette and n <= len(palette):
        # Spread n positions across the palette and snap them to integer indices
        picks = np.round(np.linspace(0, len(palette) - 1, n)).astype(int)
        # Only use them when rounding did not collapse two positions into one
        if np.unique(picks).size == n:
            return [palette[k] for k in picks]

    # Fallback: evenly spaced samples of the colormap, excluding the endpoint 1.0
    return [cmap(pos) for pos in np.linspace(0, 1, n, endpoint=False)]

def plot_count_isoform_usage(
    adata,
    gene,
    cell_types=None,   # <- optional: auto-picks CTs expressing the gene
    ct_col='claude-3-5-sonnet-20240620_simplified_ai_cell_type',
    transcript_id_col='transcript_id',
    min_frac=0.1,
    figsize=(36, 18),
    cmap='tab20',
    bar_width=0.8,
    title_fs=16,
    label_fs=10,
    tick_fs=12,
    legend_fs=12,
    *,
    layer=None,                    # e.g. "counts" if your raw counts are in a layer
    annotate_cells=True,
    annotate_counts=True,
    cells_label='expressing',      # 'expressing' (default) or 'total'
    min_frac_label=0.03,           # don't label segments smaller than this fraction
    counts_label_fs=10,
    cells_label_fs=10,
    return_table=False             # return a DataFrame of what was plotted
):
    """
    One stacked vertical bar per cell type showing isoform usage of `gene`.

    For every cell type, the values of all transcripts annotated to `gene` are
    summed over that cell type's cells and converted to fractions of the gene
    total. Transcripts reaching `min_frac` in a cell type are drawn as their own
    segment; the rest of that bar is lumped into a grey 'Other' segment. Bars are
    ordered by decreasing fraction of the transcript with the largest overall
    total. Cell types with a zero gene total are not drawn.

    Adds:
      - # cells above each bar (default = cells with >0 gene counts in that CT)
      - raw counts on each stacked segment (hidden if < min_frac_label of bar)

    Parameters
    ----------
    adata
        Object providing `.var['gene_name']`, `.var[transcript_id_col]`,
        `.obs[ct_col]` and a cells x transcripts matrix in `.X` (or
        `.layers[layer]`). NOTE(review): the segment labels are presented as raw
        counts, so the matrix is presumably expected to hold raw counts — confirm
        it has not been normalized or scaled.
    gene
        Value to match in `adata.var['gene_name']`.
    cell_types
        Cell types to plot. If None, every cell type with at least one cell whose
        summed value for the gene is > 0 is used.
    ct_col
        Column of `adata.obs` holding the cell-type labels.
    transcript_id_col
        Column of `adata.var` holding the transcript identifiers.
    min_frac
        Minimum fraction of a bar a transcript needs to get its own segment.
    figsize, cmap, bar_width, title_fs, label_fs, tick_fs, legend_fs
        Figure size, colormap name, bar width and font sizes.
    layer
        Name of the layer to read instead of `.X`.
    annotate_cells, annotate_counts
        Toggle the "n=..." label above each bar / the count label on each segment.
    cells_label
        'total' labels bars with all cells of the cell type; any other value
        labels them with the number of cells whose gene total is > 0.
    min_frac_label
        Segments smaller than this fraction get no count label.
    counts_label_fs, cells_label_fs
        Font sizes of the two kinds of annotation.
    return_table
        If True, also return a DataFrame with one row per plotted segment.

    Returns
    -------
    The matplotlib figure, or `(fig, df)` when `return_table` is True.

    Raises
    ------
    ValueError
        If a required column is missing, the gene is not found, none of the
        requested cell types is present, or no requested cell type has a
        non-zero total for the gene.
    """

    # sanity checks
    if transcript_id_col not in adata.var.columns:
        raise ValueError(f"adata.var must contain '{transcript_id_col}'")
    if 'gene_name' not in adata.var.columns:
        raise ValueError("adata.var must contain 'gene_name'")
    if ct_col not in adata.obs.columns:
        raise ValueError(f"adata.obs must contain '{ct_col}'")

    # find transcripts for this gene
    gene_mask = (adata.var['gene_name'] == gene)
    if gene_mask.sum() == 0:
        raise ValueError(f"Gene {gene!r} not found in adata.var['gene_name']")
    tx_mask = gene_mask.values
    # Plain NumPy array so positional / boolean indexing below behaves the same
    # whether the column is stored as categorical or not.
    iso_all = np.asarray(adata.var.loc[gene_mask, transcript_id_col])

    # Slice the gene's transcript columns ONCE; everything below works on this
    # (cells x isoforms) sub-matrix instead of re-slicing AnnData views per cell type.
    full_mat = adata.layers[layer] if layer is not None else adata.X
    gene_mat = full_mat[:, tx_mask]
    per_cell_total = np.asarray(gene_mat.sum(axis=1)).ravel()

    # auto-pick cell types if not provided: those with >0 total counts for this gene
    if cell_types is None:
        express_mask = per_cell_total > 0
        cell_types = adata.obs.loc[express_mask, ct_col].unique().tolist()

    # only those cell-types present in the data
    present = [ct for ct in cell_types if (adata.obs[ct_col] == ct).any()]
    if not present:
        raise ValueError("None of the requested cell-types are present")

    # pick the single most abundant transcript overall (for x-axis label)
    global_counts = np.asarray(gene_mat.sum(axis=0)).ravel()
    most_idx = int(global_counts.argmax())
    iso_most = iso_all[most_idx]

    # per cell type: row mask, per-isoform counts, gene total, fraction of iso_most
    rows_by_ct = {}
    counts_by_ct = {}
    total_by_ct = {}
    frac_most_by_ct = {}
    for ct in present:
        rows = (adata.obs[ct_col] == ct).to_numpy()
        counts = np.asarray(gene_mat[rows].sum(axis=0)).ravel()
        total = float(counts.sum())
        rows_by_ct[ct] = rows
        counts_by_ct[ct] = counts
        total_by_ct[ct] = total
        frac_most_by_ct[ct] = 0.0 if total == 0 else counts[most_idx] / total

    # drop CTs with zero expression
    present = [ct for ct in present if total_by_ct[ct] > 0]
    if not present:
        raise ValueError(f"No cells express {gene!r} in any of your chosen cell-types")

    # sort remaining CTs by fraction of iso_most (desc)
    present.sort(key=lambda ct: frac_most_by_ct[ct], reverse=True)

    # determine which isoforms get their own color (>= min_frac in at least one CT)
    shown_isos = set()
    for ct in present:
        counts = counts_by_ct[ct]
        fracs = counts / counts.sum()
        shown_isos.update(iso_all[fracs >= min_frac])
    shown_isos = sorted(shown_isos)

    # distinct colors for the isoforms that get their own segment; grey for 'Other'
    n_colors = max(1, len(shown_isos))
    base_colors = _distinct_colors(cmap, n_colors)
    iso2col = {iso: base_colors[i] for i, iso in enumerate(shown_isos)}
    other_col = (0.8, 0.8, 0.8, 1.0)

    # plot
    fig, ax = plt.subplots(figsize=figsize)
    x = np.arange(len(present))
    legend_patches = {}
    y_max = 1.0  # every bar sums to 1, so "n=" labels sit just above this height

    # optional table rows
    table_rows = []

    for i, ct in enumerate(present):
        rows = rows_by_ct[ct]

        # counts per transcript for this gene & CT, largest fraction first
        counts = counts_by_ct[ct]
        total_ct = counts.sum()
        fracs = counts / total_ct
        order = np.argsort(fracs)[::-1]
        iso_ord, fracs_ord, counts_ord = iso_all[order], fracs[order], counts[order]

        # keep majors, lump tail
        keep = fracs_ord >= min_frac
        iso_k = list(iso_ord[keep])
        frac_k = list(fracs_ord[keep])
        count_k = list(counts_ord[keep])

        tail_frac = float(fracs_ord[~keep].sum())
        tail_count = float(counts_ord[~keep].sum())
        if tail_frac > 0:
            iso_k.append('Other')
            frac_k.append(tail_frac)
            count_k.append(tail_count)

        # cells counts
        n_cells_expressing = int(np.count_nonzero(per_cell_total[rows] > 0))
        n_cells_total = int(np.count_nonzero(rows))

        # stacked bar
        bottom = 0.0
        for iso, frac, cnt in zip(iso_k, frac_k, count_k):
            col = other_col if iso == 'Other' else iso2col[iso]
            ax.bar(i, frac, bottom=bottom, width=bar_width, color=col, edgecolor='white')

            # segment label: raw counts
            if annotate_counts and frac >= min_frac_label and cnt > 0:
                ax.text(
                    i, bottom + frac/2.0, f"{int(round(cnt)):,}",
                    ha='center', va='center', fontsize=counts_label_fs
                )

            # record for table
            if return_table:
                table_rows.append({
                    'gene': gene,
                    'cell_type': ct,
                    'isoform': iso,
                    'count': int(round(cnt)),
                    'fraction': float(frac),
                    'n_cells_expressing': n_cells_expressing,
                    'n_cells_total': n_cells_total
                })

            bottom += frac
            # one legend entry per isoform, in order of first appearance
            if iso not in legend_patches:
                legend_patches[iso] = Patch(color=col, label=iso)

        # cells label above the bar
        if annotate_cells:
            n_label = n_cells_total if cells_label == 'total' else n_cells_expressing
            ax.text(
                i, y_max + 0.02, f"n={n_label:,}",
                ha='center', va='bottom', fontsize=cells_label_fs, clip_on=False
            )

    # formatting
    ax.set_xticks(x)
    ax.set_xticklabels(present, rotation=45, ha='right', fontsize=tick_fs)
    ax.set_ylabel("Fraction of transcripts", fontsize=label_fs)
    ax.set_xlabel(f"Cell types sorted by fraction of {iso_most}", fontsize=label_fs)
    ax.set_title(f"Isoform usage for {gene}", fontsize=title_fs)
    ax.set_ylim(0, 1.10)  # headroom for n= labels

    ax.legend(
        handles=list(legend_patches.values()),
        bbox_to_anchor=(1.02, 1), loc='upper left',
        frameon=False, title='Transcript ID',
        fontsize=legend_fs, title_fontsize=legend_fs
    )
    plt.tight_layout()

    if return_table:
        df = pd.DataFrame(table_rows)
        df = df[['gene','cell_type','isoform','count','fraction','n_cells_expressing','n_cells_total']]
        return fig, df
    else:
        return fig

In [19]:
# Isoform usage of TPM1 across the LLM-annotated cell types.
# `adata_combined` is created by the concatenation cell of this notebook; run that cell first.
# Every option not listed here is left at the default declared by plot_count_isoform_usage.
plot_count_isoform_usage(
    adata_combined,
    "TPM1",
    ct_col='claude-3-5-sonnet-20240620_simplified_ai_cell_type',
)

ValueError: adata.var must contain 'transcript_id'

In [14]:
adata_dict

{('ascending-colon',): AnnData object with n_obs × n_vars = 6 × 5498247
     obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'Sample', 'n_counts', 'Tissue'
     var: 'transcript_id', 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
     uns: 'log1p', 'hvg', 'rank_genes_groups',
 ('bladder',): AnnData object with n_obs × n_vars = 33 × 5498247
     obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'Sample', 'n_counts', 'Tissue'
     var: 'transcript_id', 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
     uns: 'log1p', 'hvg', 'rank_genes_groups',
 ('blood',): AnnData object with n_

In [20]:
# Concatenate adata_dict into a single adata object for plotting
# The plot_count_isoform_usage function expects a single AnnData object, not a dictionary
import scanpy as sc

# Var annotations that plot_count_isoform_usage needs on the combined object
required_var_cols = ['transcript_id', 'gene_name']

# Convert dict values to list (adata_dict values are AnnData objects)
adata_list = list(adata_dict.values())
if not adata_list:
    raise ValueError("adata_dict is empty; there is nothing to concatenate")

# Check var columns from first adata
first_adata = adata_list[0]
print(f"Var columns in first adata: {list(first_adata.var.columns)}")
print(f"Var shape: {first_adata.var.shape}")

# Concatenate all AnnData objects along the cell axis.
# join='outer' keeps all features, fill_value=0 fills features missing from an object.
# merge='same' keeps the var columns whose values are identical in every object
# (e.g. transcript_id / gene_name) and drops per-tissue statistics such as
# 'highly_variable' or 'means', which differ between tissues and would be
# misleading on the combined object.
adata_combined = sc.concat(
    adata_list,
    join='outer',
    merge='same',
    index_unique=None,
    fill_value=0,
)

# Fallback: if a required annotation was not carried over by the merge, restore just
# that column from the first object, aligned on var_names rather than copied blindly.
missing_var_cols = [
    col for col in required_var_cols
    if col not in adata_combined.var.columns and col in first_adata.var.columns
]
if missing_var_cols:
    if adata_combined.var_names.equals(first_adata.var_names):
        donor_var = first_adata.var
    else:
        # reindex raises if first_adata.var_names contains duplicates, which is
        # preferable to silently attaching annotations to the wrong features
        donor_var = first_adata.var.reindex(adata_combined.var_names)
    for col in missing_var_cols:
        adata_combined.var[col] = donor_var[col].values
    print(f"\nRestored var columns: {missing_var_cols}")

print(f"\nCombined adata shape: {adata_combined.shape}")
print(f"Available obs columns: {list(adata_combined.obs.columns)}")
print(f"Available var columns: {list(adata_combined.var.columns)}")
print(f"\nNow use 'adata_combined' instead of 'adata_dict' in the plot_count_isoform_usage call below")


Var columns in first adata: ['transcript_id', 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std']
Var shape: (5498247, 8)


: 