In [None]:
import anndata as ad
import matplotlib.pyplot as plt
import mudata as md
import muon
import scanpy as sc
import scvi
import seaborn as sns
import torch
import pandas as pd
import numpy as np
import json

In [None]:
# Seed scvi-tools' global RNG and record the library version used for this run.
scvi.settings.seed = 0
print(f"Last run with scvi-tools version: {scvi.__version__}")

In [None]:
sc.set_figure_params(figsize=(6, 6), frameon=False, dpi_save=500, )
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = './'

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

In [None]:
# Multimodal (RNA + protein) MuData with totalVI Leiden clusters.
MDATA_PATH = "/media/Lynn/data/totalVI/1st_run/mdata/mdata_leiden_dendogram.h5mu"
mdata = muon.read(MDATA_PATH)

In [None]:
# Trained totalVI model (400 epochs), re-attached to the loaded MuData.
MODEL_DIR = "/media/Lynn/data/totalVI/1st_run/my_model_400_epochs"
model = scvi.model.TOTALVI.load(MODEL_DIR, mdata)

In [None]:
# Per-modality AnnData objects (the same objects stored in mdata.mod, not copies).
rna, protein = mdata.mod["rna"], mdata.mod["protein"]

# obs / obsm keys for the totalVI clustering and latent space.
TOTALVI_CLUSTERS_KEY = "leiden_totalVI"
TOTALVI_LATENT_KEY = "X_totalVI"

In [None]:
# One-vs-rest differential expression (RNA + protein) for every totalVI cluster.
de_df = model.differential_expression(
    groupby="rna:leiden_totalVI",
    delta=0.5,
    batch_correction=True,
)

In [None]:
# Render the combined DE results table (all comparisons); pandas truncates long frames in the display.
de_df

In [None]:
# Summarise the top-N DE features per cluster as a LaTeX table.
# (pandas is already imported in the first cell.)
top_n = 5

# Work on a copy so de_df itself is never modified.
de_top = de_df.copy()

# Strip the '_protein' suffix so protein markers read like their panel names.
de_top.index = de_top.index.str.replace("_protein$", "", regex=True)

# Keep the top_n features with the highest proba_de within each cluster.
ranked = (
    de_top.sort_values(["group1", "proba_de"], ascending=[True, False])
    .groupby("group1")
    .head(top_n)
)

# Join feature names per cluster. Aggregating a single column avoids
# DataFrameGroupBy.apply, which also receives the grouping column (deprecated in
# recent pandas).
top_genes_per_cluster = (
    ranked.assign(feature=ranked.index.to_numpy())
    .groupby("group1")["feature"]
    .agg(", ".join)
    .reset_index()
)

# Header tracks top_n instead of a hard-coded "Top 5".
top_genes_per_cluster.columns = ["Cluster", f"Top {top_n} Genes/Proteins"]

# escape=True so feature names containing LaTeX specials ('_', '&', '%', ...)
# still produce a compilable table.
latex_table = top_genes_per_cluster.to_latex(
    index=False,
    caption=f"Top {top_n} marker genes or proteins per cluster based on differential expression analysis.",
    label="tab:top5_markers_de",
    escape=True,
)

print(latex_table)


In [None]:
# Persist the full DE table; the index (feature names) is written by default.
DE_RESULTS_CSV = "/media/Lynn/data/totalVI/DE_results/differential_expression_results.csv"
de_df.to_csv(DE_RESULTS_CSV)

In [None]:
# For each totalVI cluster, pick its strongest up-regulated markers from the
# "<cluster> vs Rest" comparison, ranked by lfc_median (descending).
N_PROTEIN_MARKERS = 3
N_RNA_MARKERS = 2
PROTEIN_MIN_BAYES_FACTOR = 0.7
RNA_MIN_BAYES_FACTOR = 3
RNA_MIN_NONZERO_FRACTION = 0.1  # threshold on non_zeros_proportion1
PROTEIN_SUFFIX = "_protein"

filtered_pro = {}
filtered_rna = {}
cats = rna.obs[TOTALVI_CLUSTERS_KEY].cat.categories
for c in cats:
    cid = f"{c} vs Rest"
    cell_type_df = de_df.loc[de_df.comparison == cid]
    cell_type_df = cell_type_df.sort_values("lfc_median", ascending=False)
    # Only up-regulated features.
    cell_type_df = cell_type_df[cell_type_df.lfc_median > 0]

    # Protein features carry a '_protein' suffix; match it as a suffix rather
    # than as any substring of the feature name.
    pro_rows = cell_type_df.index.str.endswith(PROTEIN_SUFFIX)

    data_pro = cell_type_df[pro_rows]
    data_pro = data_pro[data_pro["bayes_factor"] > PROTEIN_MIN_BAYES_FACTOR]

    data_rna = cell_type_df[~pro_rows]
    data_rna = data_rna[data_rna["bayes_factor"] > RNA_MIN_BAYES_FACTOR]
    data_rna = data_rna[data_rna["non_zeros_proportion1"] > RNA_MIN_NONZERO_FRACTION]

    filtered_pro[c] = data_pro.index.tolist()[:N_PROTEIN_MARKERS]
    filtered_rna[c] = data_rna.index.tolist()[:N_RNA_MARKERS]

In [None]:
# Save the per-cluster protein and RNA marker lists as JSON.
marker_outputs = (
    (filtered_pro, "/media/Lynn/data/totalVI/1st_run/filtered_protein_and_rna/new_filtered_pro.json"),
    (filtered_rna, "/media/Lynn/data/totalVI/1st_run/filtered_protein_and_rna/new_filtered_rna.json"),
)
for markers, json_path in marker_outputs:
    with open(json_path, "w") as fh:
        json.dump(markers, fh)

In [None]:
# CODEX panel channels. Any rna.obs column named after one of them gets a
# "_CDX_protein" suffix (presumably to avoid clashing with same-named RNA
# features such as CCR7).
codex_channels = [
    'DAPI', 'FoxP3', 'aSMA', 'CD4', 'CD8', 'CD31',
    'CD11c', 'IFNG', 'Pan-Cytokeratin', 'CD68', 'CD20',
    'CD66b', 'TNFa', 'CD45RO', 'CD14', 'CD11b', 'Vimentin',
    'CD163', 'PDGFRA', 'CD45', 'CCR7', 'IL10', 'CD38', 'CD69',
    'Podoplanin', 'PNAd', 'ECP', 'MPO', 'MIP-3', 'CD16', 'CXCL13',
]

rna_codex_renames = {}
for channel in codex_channels:
    if channel in rna.obs.columns:
        rna_codex_renames[channel] = f"{channel}_CDX_protein"

rna.obs.rename(columns=rna_codex_renames, inplace=True)

In [None]:
# Attach the curated joint annotation (from CSV) to rna.obs, keyed by cell ID.
ANNOTATION_CSV = "/media/Lynn/data/run_2_3_final_annotation.csv"
new_ann = pd.read_csv(ANNOTATION_CSV)

# Sanity check: the CSV must contain the two columns named below.
print(new_ann.columns)

cell_id_col = "cell_id"          # adjust if the CSV uses a different header
annotation_col = "annotation"

# cell_id -> label lookup (if a cell_id repeats, the last row wins).
cell_to_annotation = {
    cell_id: label
    for cell_id, label in zip(new_ann[cell_id_col], new_ann[annotation_col])
}

# Cells absent from the CSV end up as NaN.
rna.obs['joint_annotation'] = rna.obs_names.map(cell_to_annotation)

n_annotated = rna.obs['joint_annotation'].notna().sum()
print(f"Number of cells with new annotation: {n_annotated} / {rna.n_obs}")

# Fold the ileal transit-amplifying label into "Other".
rna.obs["joint_annotation"] = rna.obs["joint_annotation"].replace(
    "Transit Amplifying Cells (Ileum)", "Other"
)

In [None]:
# Same "_CDX_protein" suffixing as for rna.obs, applied to the protein modality's obs.
protein_codex_renames = {
    channel: f"{channel}_CDX_protein"
    for channel in codex_channels
    if channel in protein.obs.columns
}
protein.obs.rename(columns=protein_codex_renames, inplace=True)

In [None]:
# Copy the joint annotation onto the protein modality, aligned by cell (obs) name.
protein.obs["joint_annotation"] = rna.obs["joint_annotation"].reindex(protein.obs.index)

In [None]:
# Combined RNA + protein marker dotplot over the joint annotation.
# (scanpy and pandas are already imported in the first cell.)

# Flatten the per-cluster marker lists; dict.fromkeys drops duplicates while
# keeping first-seen order.
rna_features = list(dict.fromkeys(g for genes in filtered_rna.values() for g in genes))
protein_features = list(dict.fromkeys(p for prots in filtered_pro.values() for p in prots))

# DE protein names end in '_protein'; strip only that suffix to recover the
# protein modality's var_names.
PROTEIN_SUFFIX = "_protein"
protein_features_clean = [
    name[: -len(PROTEIN_SUFFIX)] if name.endswith(PROTEIN_SUFFIX) else name
    for name in protein_features
]

# Protein values as a (cells x features) DataFrame; densify if .X is sparse,
# since pd.DataFrame cannot wrap a scipy sparse matrix directly.
protein_values = mdata['protein'][:, protein_features_clean].X
if hasattr(protein_values, "toarray"):
    protein_values = protein_values.toarray()
protein_df = pd.DataFrame(
    protein_values,
    index=mdata['protein'].obs_names,
    columns=protein_features_clean,
)

# Align rows to the RNA modality, then expose each protein as an rna.obs column
# so the dotplot can show RNA genes and proteins side by side.
protein_df = protein_df.loc[mdata['rna'].obs_names]
for col in protein_df.columns:
    mdata['rna'].obs[col] = protein_df[col]

combined_features = rna_features + protein_features_clean

# NOTE(review): if a protein name is also an RNA var_name, scanpy may reject the
# key as ambiguous between obs and var — verify for this panel.
sc.pl.dotplot(
    rna,
    var_names=combined_features,
    groupby='joint_annotation',
    dendrogram=False,
    standard_scale='var',
    swap_axes=True,
    save='_dendogram_filtered_rna_protein_combined.png',
)


In [None]:
# Hand-picked RNA markers (denoised expression) shown next to the totalVI clusters.
selected_rna_markers = [
    "LCN2", "LEFTY1", "TNFRSF17", "CPA3", "S100A8", "CEACAM7",
    "FABP2", "FCRL1", "CCR7", "CSF3", "PROX1",
]
sc.pl.umap(
    rna,
    color=[TOTALVI_CLUSTERS_KEY] + selected_rna_markers,
    layer="denoised_rna",
    legend_loc="on data",
    frameon=False,
    ncols=3,
    wspace=0.2,
    save="_some_filtered_rna.png",
)

In [None]:
# UMAP of every RNA marker selected per cluster (denoised expression).
# Genes chosen for several clusters are plotted once.
all_filtered_genes = list(dict.fromkeys(g for genes in filtered_rna.values() for g in genes))
sc.pl.umap(
    rna,
    color=all_filtered_genes,
    legend_loc="on data",
    frameon=False,
    ncols=3,
    layer="denoised_rna",
    wspace=0.2,
    # Distinct filename: the previous cell already saves '_some_filtered_rna.png'.
    save='_all_filtered_rna.png',
)

In [None]:
# RNA UMAP coloured by slide, to inspect batch mixing.
muon.pl.embedding(
    mdata,
    basis="rna:X_umap",
    color="rna:slide_str",
    ncols=1,
    frameon=False,
    save='_by_batch.png',
)

In [None]:
# Denoised protein expression for every protein feature on the RNA UMAP.
denoised_protein_kwargs = dict(
    basis="rna:X_umap",
    color=protein.var_names,
    layer="denoised_protein",
    vmax="p99",  # cap each colour scale at the 99th percentile
    ncols=3,
    wspace=0.1,
    frameon=False,
    save='_codex_markers.png',
)
muon.pl.embedding(mdata, **denoised_protein_kwargs)

In [None]:
# totalVI protein foreground-probability layer for every protein, on the RNA UMAP.
foreground_kwargs = dict(
    basis="rna:X_umap",
    layer="protein_foreground_prob",
    color=protein.var_names,
    color_map="cividis",
    vmax="p99",
    ncols=3,
    wspace=0.1,
    frameon=False,
    save='_protein_foreground_prob.png',
)
muon.pl.embedding(mdata, **foreground_kwargs)

In [None]:
# UMAP coloured by response_group (blank title).
sc.pl.umap(rna, color="response_group", title=" ", frameon=False, wspace=0.2,
           save="_by_response_group.png")

In [None]:
# UMAP coloured by sampling time point (blank title).
sc.pl.umap(rna, color="time_point", title=" ", frameon=False, wspace=0.2,
           save="_by_timepoint.png")

In [None]:
# Treat year and patient_ID as labels (strings) so they plot as discrete groups.
for meta_col in ("year", "patient_ID"):
    rna.obs[meta_col] = rna.obs[meta_col].astype(str)

In [None]:
# UMAP coloured by sample year (blank title).
sc.pl.umap(rna, color="year", title=" ", frameon=False, wspace=0.2,
           save="_by_year.png")

In [None]:
# Standardize patient IDs so that e.g. '3_x' and '03_x' refer to the same patient.
def standardize_patient_ids(pid):
    """Zero-pad the leading numeric segment of a patient ID to two digits.

    The segment before the first underscore is padded only when it is all
    digits (e.g. '3_A' -> '03_A', '7' -> '07'); the rest of the ID, including
    any further underscores, is returned unchanged.
    """
    head, sep, tail = pid.partition('_')
    if head.isdigit():
        head = f"{int(head):02d}"
    return head + sep + tail

# Apply the ID normalisation, then colour the UMAP by patient.
normalized_ids = rna.obs["patient_ID"].map(standardize_patient_ids)
rna.obs["patient_ID"] = normalized_ids

sc.pl.umap(rna, color="patient_ID", title=" ", frameon=False, wspace=0.2,
           save="_by_patient.png")

In [None]:
# UMAP coloured by tissue of origin (blank title).
sc.pl.umap(rna, color="tissue", title=" ", frameon=False, wspace=0.2,
           save="_by_tissue.png")

In [None]:
# Number of cells in each totalVI cluster.
cluster_counts = rna.obs[TOTALVI_CLUSTERS_KEY].value_counts()

# Per-cell copy of its cluster's size. Series.map on a categorical column returns
# a Categorical when the mapped values are unique (typical for cluster sizes),
# which scanpy would colour as discrete classes. Cast to float so the plot is a
# continuous gradient (float also tolerates NaN for unclustered cells).
rna.obs['cluster_cell_count'] = (
    rna.obs[TOTALVI_CLUSTERS_KEY].map(cluster_counts).astype('float64')
)

# UMAP with a colour gradient by cluster size.
sc.pl.umap(
    rna,
    color='cluster_cell_count',
    wspace=0.4,
    save='_cell_counts.png',
)

In [None]:
# UMAP coloured by totalVI cluster, with the cluster numbers drawn on the embedding.
sc.pl.umap(
    rna,
    color=TOTALVI_CLUSTERS_KEY,
    wspace=0.4,
    save='_overlayed_cluster_numbers.png',
    legend_loc='on data',
    title=' ',
)

In [None]:
# Relabel the placeholder "??" Xenium annotation as "Other", then plot it.
rna.obs["xenium_annotation"] = rna.obs["xenium_annotation"].replace({"??": "Other"})
sc.pl.umap(rna, color="xenium_annotation", title="Xenium-only Annotation", wspace=0.4,
           save="_with_xenium_annotation.png")

In [None]:
# Legend labels of the form "<xenium cluster id>: <xenium annotation>".
required_cols = ["xenium_annotation", "xenium_leiden_0.7"]
missing_cols = [col for col in required_cols if col not in rna.obs.columns]
if missing_cols:
    # Previously the plot below ran anyway and failed with an unrelated KeyError.
    raise KeyError(f"rna.obs is missing required column(s): {missing_cols}")

rna.obs["xenium_annotation_with_cluster_number"] = (
    rna.obs["xenium_leiden_0.7"].astype(str) + ": " + rna.obs["xenium_annotation"].astype(str)
)

sc.pl.umap(
    rna,
    color='xenium_annotation_with_cluster_number',
    wspace=0.4,
    save='_with_xenium_annotation_with_numbers.png',
)

In [None]:
# Xenium-only Leiden (res 0.7) clusters, numbers drawn on the embedding.
sc.pl.umap(rna, color="xenium_leiden_0.7", legend_loc="on data", title=" ", wspace=0.4,
           save="_with_overlayed_xenium_only_clustering.png")

In [None]:
# Broad cell subset for each Xenium leiden_0.7 cluster id (string keys '0'..'36').
# Clusters 31-34 are 'Unassigned' based on their '??' annotation.
_subset_to_clusters = {
    'T cells': [0, 9, 21],
    'Stroma': [1, 8, 12, 18, 19, 35, 36],
    'B/ Plasma cells': [2, 7],
    'Epithelium': [3, 6, 10, 11, 13, 15, 16, 20, 24, 25, 26, 27],
    'Myeloid cells': [4, 5, 14, 22, 23, 28, 29, 30],
    'Unassigned': [17, 31, 32, 33, 34],
}

# Invert to cluster id -> subset, ordered by numeric cluster id.
cluster_to_subset_mapping = dict(
    sorted(
        (
            (str(cluster_id), subset_name)
            for subset_name, cluster_ids in _subset_to_clusters.items()
            for cluster_id in cluster_ids
        ),
        key=lambda item: int(item[0]),
    )
)

In [None]:
# Broad Xenium cell subset per cell. Cast the cluster ids to str first so they
# match the string keys of cluster_to_subset_mapping even if the column stores
# numbers (as done for the numbered legend above); unmapped ids become NaN.
rna.obs['xenium_cell_subset'] = (
    rna.obs['xenium_leiden_0.7'].astype(str).map(cluster_to_subset_mapping)
)

In [None]:
# UMAP coloured by the broad Xenium cell subset.
sc.pl.umap(rna, color="xenium_cell_subset", wspace=0.4,
           save="_with_xenium_cell_subsets.png")

In [None]:
def plot_umap_with_cluster_labels(
    adata,
    color_by='xenium_cell_subset',
    label_by=TOTALVI_CLUSTERS_KEY,
    save_name=None,
    point_size=5,
    figsize=(6, 6),
    dpi=150,
):
    """
    Plot a UMAP coloured by `color_by` and overlay one text label per `label_by` group.

    Each label is drawn at the mean UMAP position of the cells in its group.

    Parameters:
        adata: AnnData object with `obsm['X_umap']`.
        color_by: column in adata.obs used to colour the points.
        label_by: column in adata.obs whose groups are labelled.
        save_name: path to save the figure to (optional).
        point_size: marker size passed to `sc.pl.umap`.
        figsize: matplotlib figure size.
        dpi: resolution used when saving.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Points coloured by `color_by`, drawn into our own axes.
    sc.pl.umap(
        adata,
        color=color_by,
        ax=ax,
        show=False,
        size=point_size,
        legend_loc='right margin',
    )

    # Cell coordinates next to their group label.
    umap_coords = np.asarray(adata.obsm['X_umap'])
    obs_df = adata.obs[[label_by]].copy()
    obs_df['UMAP1'] = umap_coords[:, 0]
    obs_df['UMAP2'] = umap_coords[:, 1]

    # observed=True: for a categorical `label_by`, only groups that actually have
    # cells get a center (empty categories would yield NaN positions).
    centers = obs_df.groupby(label_by, observed=True)[['UMAP1', 'UMAP2']].mean()

    for cluster, row in centers.iterrows():
        ax.text(row['UMAP1'], row['UMAP2'], str(cluster),
                color='black', fontsize=12, fontweight='bold',
                ha='center', va='center')

    if save_name:
        fig.savefig(save_name, bbox_inches='tight', dpi=dpi)
    plt.show()


# Colour by broad Xenium subset and label each totalVI cluster at its centroid.
plot_umap_with_cluster_labels(
    rna,
    save_name='umap_cell_subset_with_clusters.png',
    label_by=TOTALVI_CLUSTERS_KEY,
    color_by='xenium_cell_subset',
)


### Inspect suspicious clusters and flag artifact cells for removal

In [None]:
# Where do cluster-22 cells come from? List the core and slide of each one.
cells_leiden_22 = rna[rna.obs[TOTALVI_CLUSTERS_KEY] == '22']
result = cells_leiden_22.obs.loc[:, ['core_ID', 'slide_ID']]
result

In [None]:
# All cells of core X3Y5 on slide ID_0056777 (an AnnData view).
in_core_X3Y5 = rna.obs['core_ID'] == 'X3Y5'
on_slide_56777 = rna.obs['slide_str'] == 'ID_0056777'
X3Y5_56777 = rna[in_core_X3Y5 & on_slide_56777]

In [None]:
# Cluster-22 cells within that core.
is_cluster_22 = X3Y5_56777.obs[TOTALVI_CLUSTERS_KEY] == '22'
subset = X3Y5_56777[is_cluster_22]

In [None]:
# obs columns are added below, so work on a real AnnData rather than a view
# (writing to a view's obs triggers an implicit copy with a warning).
if subset.is_view:
    subset = subset.copy()

# Prefer the raw Xenium counts layer when present.
counts = subset.layers['xenium_counts'] if 'xenium_counts' in subset.layers else subset.X

# Per-cell totals. np.asarray(...).ravel() yields a 1-D array for dense arrays,
# for the (n, 1) np.matrix that scipy sparse matrices return from .sum(axis=1),
# and for sparse arrays. The previous `hasattr(counts, 'A1')` guard inspected the
# input instead of the sum result, so sparse input could leave an (n, 1) matrix.
subset.obs['n_transcripts'] = np.asarray(counts.sum(axis=1)).ravel()
subset.obs['n_genes'] = np.asarray((counts > 0).sum(axis=1)).ravel()

# Per-metadata-group QC statistics for the selected cells.
grouping_vars = ['slide_ID', 'core_ID', 'patient_ID', 'time_point', 'response_group', 'year', 'tissue']

summary_stats = {}

for var in grouping_vars:
    # observed=True: only groups actually present in `subset`. With categorical
    # columns the default would add an empty group for every category in the
    # full dataset (and pandas warns that the default is changing).
    grouped = subset.obs.groupby(var, observed=True)

    # Transcript / gene statistics computed from the counts matrix.
    stats = grouped.agg(
        total_cells=('n_transcripts', 'count'),
        total_transcripts=('n_transcripts', 'sum'),
        avg_transcripts_per_cell=('n_transcripts', 'mean'),
        avg_genes_per_cell=('n_genes', 'mean'),
    )

    # Negative-control rates relative to obs['total_counts'].
    control_totals = grouped.agg(
        total_counts=('total_counts', 'sum'),
        control_probe_counts=('control_probe_counts', 'sum'),
        control_codeword_counts=('control_codeword_counts', 'sum'),
    )
    negative_rates = pd.DataFrame({
        'negative_dna_pct': 100 * control_totals['control_probe_counts'] / control_totals['total_counts'],
        'negative_decoding_pct': 100 * control_totals['control_codeword_counts'] / control_totals['total_counts'],
    })

    summary_stats[var] = stats.join(negative_rates)
    
# All per-group QC metric names produced above.
metrics = [
    'total_cells', 'total_transcripts', 'avg_transcripts_per_cell',
    'avg_genes_per_cell', 'negative_dna_pct', 'negative_decoding_pct',
]

# Subset of metrics whose spread across groups is summarised.
relevant_metrics = [
    'avg_transcripts_per_cell', 'avg_genes_per_cell',
    'negative_dna_pct', 'negative_decoding_pct',
]


def _qc_spread_row(metadata_name, per_group_stats):
    """Return {'metadata': name, '<metric>_mean': ..., '<metric>_std': ...} over relevant_metrics."""
    spread = {'metadata': metadata_name}
    for metric_name in relevant_metrics:
        column = per_group_stats[metric_name]
        spread[f'{metric_name}_mean'] = column.mean()
        spread[f'{metric_name}_std'] = column.std()
    return spread


# One row per metadata field: mean / std of each metric across that field's groups.
summary_list = [_qc_spread_row(name, stats_df) for name, stats_df in summary_stats.items()]
summary_table = pd.DataFrame(summary_list)

summary_table

In [None]:
# Which core/slide combination contributes the most cluster-22 cells?
# Recomputed here instead of reusing `result`, which a later cell overwrites
# with cluster-18 cells (re-running this cell would then silently use those).
cells_22_meta = rna[rna.obs[TOTALVI_CLUSTERS_KEY] == '22'].obs[['core_ID', 'slide_ID']]
combo_counts = cells_22_meta.value_counts()  # counts rows by unique combination

most_common_combo = combo_counts.idxmax()
most_common_count = combo_counts.max()

print(f"The combination with the most cells is {most_common_combo} with {most_common_count} cells.")

In [None]:
# Highlight cluster 22 in red against grey for the rest of core X3Y5 / slide ID_0056777.
# Work on a copy so the helper column does not land on the shared object.
adata_tmp = X3Y5_56777.copy()
cluster_labels = adata_tmp.obs[TOTALVI_CLUSTERS_KEY].astype(str)
adata_tmp.obs['cluster_highlight'] = cluster_labels.where(cluster_labels == '22', 'other')

# Sorted categories are ['22', 'other'] -> red for '22', lightgrey for 'other'.
sc.pl.spatial(
    adata_tmp,
    color='cluster_highlight',
    palette=['red', 'lightgrey'],
    spot_size=10,
    title=' ',
    wspace=0.4,
    save='_X3Y5_56777_cluster_22.png',
)

In [None]:
# Which core/slide combinations contribute the most cluster-18 cells?
N_TOP_COMBOS = 20

cells_leiden_18 = rna[rna.obs[TOTALVI_CLUSTERS_KEY] == '18']

# Separate name so the cluster-22 `result` from earlier is not overwritten.
combos_18 = cells_leiden_18.obs[['core_ID', 'slide_ID']]

# Cells per core_ID / slide_ID combination, largest first, top N_TOP_COMBOS only.
top20_combos = combos_18.value_counts().head(N_TOP_COMBOS)

print(f"Top {N_TOP_COMBOS} core_ID / slide_ID combinations with the most cells:")
print(top20_combos)

In [None]:
# Cells of core X2Y4 on slide ID_0022110 (presumably chosen from the cluster-18 list above).
X2Y4_22110 = rna[(rna.obs['core_ID'] == 'X2Y4') & (rna.obs['slide_str'] == 'ID_0022110')]
# Previous, misleading name (the core is X2Y4, not X3Y6), kept so later cells still work.
X3Y6_22110 = X2Y4_22110

In [None]:
# Highlight cluster 18 in red against grey for the rest of core X2Y4 / slide ID_0022110.
# Copy so the helper column does not land on the shared object.
adata_tmp = X3Y6_22110.copy()  # NOTE: despite the variable name, this is core X2Y4

# Keep label '18', relabel every other cluster as 'other'.
adata_tmp.obs['cluster_highlight'] = adata_tmp.obs[TOTALVI_CLUSTERS_KEY].astype(str)
adata_tmp.obs['cluster_highlight'] = adata_tmp.obs['cluster_highlight'].where(
    adata_tmp.obs['cluster_highlight'] == '18',
    'other',
)

# Sorted categories are ['18', 'other'] -> red for '18', lightgrey for 'other'.
sc.pl.spatial(
    adata_tmp,
    spot_size=10,
    color='cluster_highlight',
    palette=['red', 'lightgrey'],
    wspace=0.4,
    title=' ',
    save='_X2Y4_22110_cluster_18.png',
)

In [None]:
# IDs of every cell in the suspicious clusters 18 and 22 (labels compared as strings).
clusters_of_interest = ['18', '22']
in_clusters_mask = rna.obs[TOTALVI_CLUSTERS_KEY].astype(str).isin(clusters_of_interest)
cells_in_clusters = rna.obs_names[in_clusters_mask].tolist()

In [None]:
output_path = '/media/Lynn/data/totalVI/1st_run/cells_in_clusters_18_22.txt'

# One cell ID per line, each newline-terminated.
with open(output_path, 'w') as f:
    f.writelines(f"{cell_id}\n" for cell_id in cells_in_clusters)

print(f"Saved {len(cells_in_clusters)} cell IDs to {output_path}")

In [None]:
# Cells of core X3Y2 on slide ID_0022111. `.copy()` makes this a real AnnData,
# so adding an obs column below does not modify a view (which would trigger an
# implicit copy and warning).
adata_tmp = rna[(rna.obs['core_ID'] == 'X3Y2') & (rna.obs['slide_str'] == 'ID_0022111')].copy()

# Keep cluster '26', relabel everything else as 'other'.
adata_tmp.obs['cluster_highlight'] = adata_tmp.obs[TOTALVI_CLUSTERS_KEY].astype(str)
adata_tmp.obs['cluster_highlight'] = adata_tmp.obs['cluster_highlight'].where(
    adata_tmp.obs['cluster_highlight'] == '26',
    'other',
)

# Spatial extent of the core, used to choose a crop window.
coords = adata_tmp.obsm['spatial']
x_min, y_min = np.min(coords, axis=0)
x_max, y_max = np.max(coords, axis=0)

print(f"x_min: {x_min}, x_max: {x_max}")
print(f"y_min: {y_min}, y_max: {y_max}")


In [None]:
# Crop a size x size square from the core, offset from its left and top edges
# (presumably to zoom in on the cluster-26 region plotted below).
size = 1000       # side length of the square
x_offset = 300    # move right from the left edge
y_offset = 400    # move down from the top edge

coords = adata_tmp.obsm['spatial']

# Anchor on the core's actual bounds instead of the previously hard-coded
# 7389.6 / 4628, so the crop stays correct if the data change.
core_x_min = coords[:, 0].min()
core_y_max = coords[:, 1].max()

x_min_crop = core_x_min + x_offset
x_max_crop = x_min_crop + size

y_max_crop = core_y_max - y_offset
y_min_crop = y_max_crop - size

# Cells inside the square window.
mask = (
    (coords[:, 0] >= x_min_crop) & (coords[:, 0] <= x_max_crop) &
    (coords[:, 1] >= y_min_crop) & (coords[:, 1] <= y_max_crop)
)

adata_cropped = adata_tmp[mask].copy()

In [None]:
# Cropped window: cluster 26 in red, everything else light grey
# (sorted categories are ['26', 'other']).
cropped_plot_kwargs = dict(
    color='cluster_highlight',
    palette=['red', 'lightgrey'],
    spot_size=10,
    title=' ',
    wspace=0.4,
    save='_fake_cluster_26.png',
)
sc.pl.spatial(adata_cropped, **cropped_plot_kwargs)

## Add new annotation

In [None]:
# Re-apply the curated joint annotation from CSV to rna.obs. This repeats an
# earlier cell with the same inputs, so the result is the same annotation.
annotation_csv_path = "/media/Lynn/data/run_2_3_final_annotation.csv"
new_ann = pd.read_csv(annotation_csv_path)
print(new_ann.columns)  # the CSV must provide the two columns named below

cell_id_col = "cell_id"          # adjust if the CSV uses a different header
annotation_col = "annotation"

# cell_id -> label; a repeated cell_id keeps its last label.
cell_to_annotation = dict(
    (cell_id, label)
    for cell_id, label in zip(new_ann[cell_id_col], new_ann[annotation_col])
)

rna.obs['joint_annotation'] = rna.obs_names.map(cell_to_annotation)  # unmatched -> NaN

matched = rna.obs['joint_annotation'].notna().sum()
print(f"Number of cells with new annotation: {matched} / {rna.n_obs}")

rna.obs["joint_annotation"] = rna.obs["joint_annotation"].replace(
    {"Transit Amplifying Cells (Ileum)": "Other"}
)


In [None]:
# Drop the "(?)" uncertainty marker from the M-cell label.
rna.obs["joint_annotation"] = rna.obs["joint_annotation"].replace({"M-Cells (?)": "M-Cells"})

In [None]:
# Final UMAP coloured by the joint annotation (blank title).
sc.pl.umap(rna, color="joint_annotation", title=" ", wspace=0.4,
           save="_with_joint_annotation.png")