In [None]:
import os
import shutil
import sys
from pathlib import Path

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import sopa
import spatialdata as sd
import spatialdata_plot
import squidpy as sq
from scipy.stats import pearsonr
from spatialdata_io import xenium, codex

# Make the project's helper modules in ../src importable before importing them.
sys.path.append(os.path.abspath('../src'))
import load_sdata, merge_xenium_sdata, qc

## Read the spatialdata

### Xenium

In [None]:
# Load each Xenium dataset into a SpatialData object
sdata_xenium_first_slide = load_sdata.get_xenium_slide_data('0022110')
sdata_xenium_second_slide = load_sdata.get_xenium_slide_data('0022111')

In [None]:
sdata_xenium_first_slide['column_1']

### CODEX

#### Slide by slide

In [None]:
sdata_codex_first_slide = load_sdata.get_codex_slide_data('0022110')
sdata_codex_second_slide = load_sdata.get_codex_slide_data('0022111')

In [None]:
sdata_codex_first_slide

#### Column by column

In [None]:
sdata_codex_updated_s1 = load_sdata.get_codex_updated_columns_data('0022110')
sdata_codex_updated_s2 = load_sdata.get_codex_updated_columns_data('0022111')

## Explore the Data

### Xenium

In [None]:
# One panel per Xenium column of the first slide, laid out side by side.
fig, axes = plt.subplots(1, 4, figsize=(20, 20))
axes = axes.flatten()

for panel, column_number in zip(axes, range(1, 5)):
    column_sdata = sdata_xenium_first_slide[f'column_{column_number}']
    column_sdata.pl.render_images("morphology_focus").pl.show(
        ax=panel, title=f"Column {column_number}"
    )

# Save the figure
plt.savefig('/media/Lynn/pictures/xenium_columns.png', dpi=300, bbox_inches='tight')

#### Check how many transcripts are unassigned vs. assigned to a cell in column 2

In [None]:
# Materialise the transcripts of column 2 as an in-memory table.
transcripts_c2 = sdata_xenium_first_slide['column_2']["transcripts"].compute()

# Transcripts whose cell_id is the literal "UNASSIGNED" were not attributed to any cell.
is_unassigned = transcripts_c2["cell_id"] == "UNASSIGNED"
total_unassigned = is_unassigned.sum()
total_assigned = len(transcripts_c2) - total_unassigned

print(f"Transcripts with UNASSIGNED cell_id: {total_unassigned}")
print(f"Transcripts with assigned cell_id: {total_assigned}")

#### Visualize only 3rd spot from column 2

In [None]:
def crop0(x):
    """Restrict `x` to the x/y box (1000, 39000)-(12500, 47000) of the "global" coordinate system."""
    return sd.bounding_box_query(
        x,
        min_coordinate=[1000, 39000],
        max_coordinate=[12500, 47000],
        axes=("x", "y"),
        target_coordinate_system="global",
    )


spot_3_column_2 = crop0(sdata_xenium_first_slide['column_2'])

# Render the morphology image of the cropped spot and keep whatever show() hands back.
# NOTE(review): pl.show() appears to return None unless return_ax=True is passed — confirm;
# if so, the pyplot fallback below is the path that is always taken.
fig = spot_3_column_2.pl.render_images("morphology_focus").pl.show(figsize=(10, 10))

# Save through the returned object when there is one, otherwise through pyplot's current figure.
save_figure = plt.savefig if fig is None else fig.savefig
save_figure(
    '/media/Lynn/pictures/Data_Integration/cropped_xenium_morphology_image.png',
    dpi=300,
    bbox_inches='tight',
)

In [None]:
spot_3_column_2.pl.render_labels().pl.show() # segmentation mask

In [None]:
# Materialise the transcripts that fall inside the cropped spot.
transcripts_s3_c2 = spot_3_column_2["transcripts"].compute()

# Transcripts whose cell_id is the literal "UNASSIGNED" were not attributed to any cell.
spot_is_unassigned = transcripts_s3_c2["cell_id"] == "UNASSIGNED"
total_unassigned = spot_is_unassigned.sum()
total_assigned = len(transcripts_s3_c2) - total_unassigned

print(f"Transcripts with UNASSIGNED cell_id: {total_unassigned}")
print(f"Transcripts with assigned cell_id: {total_assigned}")

#### See what the dots outside of the shape are

In [None]:
def crop01(x):
    """Restrict `x` to the x/y box (10000, 46000)-(12000, 47000) of the "global" coordinate system."""
    return sd.bounding_box_query(
        x,
        min_coordinate=[10000, 46000],
        max_coordinate=[12000, 47000],
        axes=("x", "y"),
        target_coordinate_system="global",
    )


# The slide object is keyed by column name ('column_1', 'column_2', ...) everywhere else in this
# notebook, and this box lies inside the column-2 spot cropped above, so index by 'column_2'
# instead of the integer 3.
crop01(sdata_xenium_first_slide['column_2']).pl.render_images("morphology_focus").pl.show()

In [None]:
crop01(sdata_xenium_first_slide['column_2'])["transcripts"].compute()

#### Zoom on a specific area

In [None]:
def crop1(x):
    """Restrict `x` to the x/y box (8000, 43000)-(10000, 44000) of the "global" coordinate system."""
    return sd.bounding_box_query(
        x,
        min_coordinate=[8000, 43000],
        max_coordinate=[10000, 44000],
        axes=("x", "y"),
        target_coordinate_system="global",
    )


spot_3_column_2_zoomed = crop1(sdata_xenium_first_slide['column_2'])

# Render the morphology image of the zoomed region and keep whatever show() hands back.
fig = spot_3_column_2_zoomed.pl.render_images("morphology_focus").pl.show(figsize=(10, 10))

# Save through the returned object when there is one, otherwise through pyplot's current figure.
save_figure = plt.savefig if fig is None else fig.savefig
save_figure(
    '/media/Lynn/pictures/Data_Exploration/cropped_xenium_morphology_zoomed.png',
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Render the segmentation labels of the zoomed region.
fig = spot_3_column_2_zoomed.pl.render_labels().pl.show(figsize=(10, 10))

# Save through the returned object when there is one, otherwise through pyplot's current figure.
save_figure = plt.savefig if fig is None else fig.savefig
save_figure(
    '/media/Lynn/pictures/Data_Exploration/cropped_xenium_labels_zoomed.png',
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Titles for morphology_focus channels 0-3, in channel order.
channel_titles = [
    "DAPI (nuclear)",
    "AlphaSMA/Vimentin (interior - protein)",
    "ATP1A1/CD45/E-Cadherin (boundary)",
    "18S (interior - RNA)",
]

# Show each channel of the zoomed region as its own figure.
for channel_index, channel_title in enumerate(channel_titles):
    spot_3_column_2_zoomed.pl.render_images("morphology_focus", channel=channel_index).pl.show(
        title=channel_title, figsize=(10, 10)
    )

In [None]:
spot_3_column_2_zoomed.tables['table'].obs

#### Normalize and select highly variable genes

In [None]:
# Normalise, log-transform and flag highly variable genes on the zoomed region's table.
# NOTE(review): these calls modify the table in place, so re-running this cell would
# normalise and log-transform the same table a second time.
sc.pp.normalize_total(spot_3_column_2_zoomed.tables["table"])
sc.pp.log1p(spot_3_column_2_zoomed.tables["table"])
sc.pp.highly_variable_genes(spot_3_column_2_zoomed.tables["table"])
# Show the per-gene annotations ordered by the "means" column
# (presumably written by highly_variable_genes above — confirm).
spot_3_column_2_zoomed.tables["table"].var.sort_values("means")

In [None]:
gene_name = "OLFM4"
spot_3_column_2_zoomed.pl.render_images("morphology_focus").pl.render_shapes(
    "cell_circles",
    color=gene_name,
).pl.show(title=f"{gene_name} expression over Morphology image", coordinate_systems="global", figsize=(10, 5))

##### NCBI Gene Summary for OLFM4 Gene 
This gene was originally cloned from human myeloblasts and found to be selectively expressed in inflamed colonic epithelium. This gene encodes a member of the olfactomedin family. The encoded protein is an antiapoptotic factor that promotes tumor growth and is an extracellular matrix glycoprotein that facilitates cell adhesion. [provided by RefSeq, Mar 2011]

##### Gersemann, Michael et al. “Olfactomedin-4 is a glycoprotein secreted into mucus in active IBD.” Journal of Crohn's & colitis vol. 6,4 (2012): 425-34. doi:10.1016/j.crohns.2011.09.013
OLFM4 is overexpressed in active IBD and secreted into mucus. The induction is triggered by bacteria through the Notch pathway and also by the cytokine IL-22. OLFM4 seems to be of functional relevance in IBD as a mucus component, possibly by binding defensins.

#### Leiden Clustering of all cells in a spot

In [None]:
# Work on the table of the cropped spot (spot 3 of column 2).
adata = spot_3_column_2.tables['table']

# Normalise and log-transform in place, then PCA -> neighbour graph -> UMAP -> Leiden clusters.
# NOTE(review): the name `adata` is reassigned to a different table in the QC section below.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
# Low resolution (0.1), i.e. a coarse clustering.
sc.tl.leiden(adata, resolution=0.1)

# UMAP coloured by the Leiden cluster label.
sc.pl.umap(adata, color="leiden")

In [None]:
# Read the per-cell annotations through `.tables['table']`, as every other cell in this notebook does.
# NOTE(review): the singular `.table` accessor appears to be deprecated in recent spatialdata
# releases — confirm against the installed version.
cells_s3_c2 = spot_3_column_2.tables['table'].obs
# Number of cells in each Leiden cluster.
cells_s3_c2["leiden"].value_counts()

### CODEX

In [None]:
print(sdata_codex_first_slide.images['ID_0022110_Scan1.er'])


In [None]:
scale0 = sdata_codex_first_slide.images['ID_0022110_Scan1.er']['/scale0']
img_codex_first_slide = scale0['image']
print(img_codex_first_slide)

#### Visualize each channel (without any image processing)

In [None]:
channel_values = img_codex_first_slide.coords["c"].values
print(channel_values)

In [None]:
# Extract the NumPy array of the image data
img_array = img_codex_first_slide.values

# Get the list of protein names from the 'c' coordinate
protein_names = img_codex_first_slide.coords['c'].values

In [None]:
# Visualize each channel in a grid that is sized from the actual channel count, so no channel
# is silently dropped when the image has more than 30 channels.
num_channels = img_array.shape[0]
n_cols = 6
n_rows = max(1, int(np.ceil(num_channels / n_cols)))
# 2.4 inches per row reproduces the former 15x12 figure for a 5-row grid.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 2.4 * n_rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    # Hide the frame on every panel, including unused trailing panels.
    ax.axis("off")
    if i < num_channels:
        ax.imshow(img_array[i, :, :], cmap="gray")
        ax.set_title(protein_names[i])  # Set the protein name as the title

plt.tight_layout()
plt.show()

#### DAPI channel (Normalized)

In [None]:
# Extract the DAPI channel data by position (first entry along the "c" dimension).
# NOTE(review): assumes channel 0 is DAPI; later cells select it by name with sel(c="DAPI") — confirm.
dapi_data = img_codex_first_slide.isel(c=0)

# Convert to numpy array for visualization
dapi_array = dapi_data.values

# Normalize the DAPI image data (min-max scaling to the 0-1 range)
# NOTE(review): a completely constant image would make the denominator zero.
dapi_array = dapi_array.astype(float)
dapi_array = (dapi_array - np.min(dapi_array)) / (np.max(dapi_array) - np.min(dapi_array))

# Visualize the DAPI channel
plt.figure(figsize=(8, 8))
plt.imshow(dapi_array, cmap="inferno")  
plt.title("DAPI Channel")
plt.axis("off")
plt.show()

#### DAPI channel (Contrast Adjusted)

In [None]:
def contrast_stretch(image, low_perc=2, high_perc=98):
    """Rescale intensity using percentiles to enhance contrast.

    Parameters
    ----------
    image : array-like
        Pixel intensities (any shape).
    low_perc, high_perc : float
        Percentiles mapped to 0 and 1 respectively; values outside are clipped.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as `image`, with values in [0, 1].
    """
    low, high = np.percentile(image, (low_perc, high_perc))
    if high <= low:
        # Degenerate range (flat or very sparse channel): dividing by zero would give NaN for
        # every pixel equal to the percentile. Keep the limiting behaviour instead: pixels
        # above the percentile saturate to 1, everything else is 0.
        return (np.asarray(image) > high).astype(float)
    return np.clip((image - low) / (high - low), 0, 1)

In [None]:
# Pull the DAPI plane (selected by channel name) into memory.
img_dapi = img_codex_first_slide.sel(c="DAPI").compute().values

# Rescale contrast on the plane shaped as (y, x)
dapi_plane_shape = (img_codex_first_slide.sizes["y"], img_codex_first_slide.sizes["x"])
img_dapi_rescaled = contrast_stretch(img_dapi.reshape(dapi_plane_shape))

# Display the rescaled image
plt.figure(figsize=(8, 8))
plt.imshow(img_dapi_rescaled, cmap="inferno")
plt.title("DAPI Channel (Contrast Adjusted)")
plt.axis("off")
plt.show()

#### CD4 channel (Contrast Adjusted)

In [None]:
# Flatten the image array for histogram analysis
img_cd4 = img_codex_first_slide.sel(c="CD4").compute().values.flatten()

# Plot the histogram
plt.figure(figsize=(8, 6))
plt.hist(img_cd4, bins=100, color="blue", alpha=0.7)
plt.xlabel("Pixel Intensity")
plt.ylabel("Frequency")
plt.title("CD4 Channel Intensity Distribution")
plt.yscale("log")  # Log scale to see small values
plt.show()

In [None]:
# Pull the CD4 plane (selected by channel name) into memory.
img_cd4 = img_codex_first_slide.sel(c="CD4").compute().values

# Rescale contrast on the plane shaped as (y, x)
cd4_plane_shape = (img_codex_first_slide.sizes["y"], img_codex_first_slide.sizes["x"])
img_cd4_rescaled = contrast_stretch(img_cd4.reshape(cd4_plane_shape))

# Display the rescaled image
plt.figure(figsize=(8, 8))
plt.imshow(img_cd4_rescaled, cmap="inferno")
plt.title("CD4 Channel (Contrast Adjusted)")
plt.axis("off")
plt.show()

#### TNFa channel (Contrast Adjusted)

In [None]:
# Pull the TNFa plane (selected by channel name) into memory.
img_TNFa = img_codex_first_slide.sel(c="TNFa").compute().values

# Rescale contrast on the plane shaped as (y, x)
tnfa_plane_shape = (img_codex_first_slide.sizes["y"], img_codex_first_slide.sizes["x"])
img_TNFa_rescaled = contrast_stretch(img_TNFa.reshape(tnfa_plane_shape))

# Display the rescaled image
plt.figure(figsize=(8, 8))
plt.imshow(img_TNFa_rescaled, cmap="inferno")
plt.title("TNFa Channel (Contrast Adjusted)")
plt.axis("off")
plt.show()

## Compute QC metrics

### Xenium

#### On a column

In [None]:
adata = sdata_xenium_first_slide['column_2'].tables["table"]
adata

In [None]:
# With inplace=True the metrics are written into adata.obs / adata.var and the call returns None,
# so the result is not assigned to a variable.
sc.pp.calculate_qc_metrics(adata, percent_top=(10, 20, 50, 150), inplace=True)
adata

##### **Observation-Level QC Metrics (obs)**
These metrics describe properties of individual cells.

###### **1. `n_genes_by_counts`**  
- The number of genes detected in each cell (i.e., genes with nonzero expression).  
- Helps identify low-quality cells, as damaged or dying cells tend to have fewer detected genes.

###### **2. `log1p_n_genes_by_counts`**  
- The log-transformed (`log1p`, meaning log(1 + x)) version of `n_genes_by_counts`.  
- This transformation makes it easier to compare across cells with different transcript levels and reduces the impact of extreme values.

###### **3. `log1p_total_counts`**  
- The log-transformed (`log1p`) version of `total_counts`, which is the sum of all transcript counts per cell.  
- Helps normalize data and visualize expression levels across cells with different total counts.

###### **4. `pct_counts_in_top_10_genes`**  
- The percentage of a cell's total transcript counts that come from its top 10 most expressed genes.  
- High values indicate potential dominance of a few genes, which might suggest technical artifacts or strong cell-type specificity.

---

##### **Gene-Level QC Metrics (var)**
These metrics describe properties of individual genes.

###### **1. `n_cells_by_counts`**  
- The number of cells in which a given gene is detected (i.e., has nonzero expression).  
- Helps filter out genes that are expressed in very few cells, which might be noise or artifacts.

###### **2. `mean_counts`**  
- The average expression level of the gene across all cells.  
- Useful for identifying highly expressed or rarely expressed genes.

###### **3. `log1p_mean_counts`**  
- The log-transformed version of `mean_counts` (`log1p` to handle zeros and make distributions more normal-like).

###### **4. `pct_dropout_by_counts`**  
- The percentage of cells in which a given gene is **not** detected (has a count of zero).  
- High dropout rates can indicate sparse expression or technical issues in detection.

###### **5. `total_counts`**  
- The total expression count for each gene across all cells.  
- Helps assess the global expression of genes.

###### **6. `log1p_total_counts`**  
- The log-transformed version of `total_counts`, making it easier to compare across genes with varying expression levels.

---

##### **How to Use These Metrics for QC**

###### **Cell-Level Filtering**
- Remove cells with very low `n_genes_by_counts` (low complexity, likely dead or low-quality cells).
- Filter out cells with extremely high `total_counts` (potential doublets or technical artifacts).
- Check `pct_counts_in_top_10/20/50/150_genes`; if a few genes dominate, investigate potential biases.

###### **Gene-Level Filtering**
- Remove genes detected in very few cells (`n_cells_by_counts` too low).
- Filter out genes with extremely high `pct_dropout_by_counts`, unless they are expected to be rare markers.

In [None]:
adata.obs['pct_counts_in_top_20_genes']

In [None]:
# Total counts over all cells: the shared denominator for both percentages.
total_counts_sum = adata.obs["total_counts"].sum()

# Share of counts coming from negative control probes / negative control codewords, in percent.
cprobes = adata.obs["control_probe_counts"].sum() / total_counts_sum * 100
cwords = adata.obs["control_codeword_counts"].sum() / total_counts_sum * 100

print(f"Negative DNA probe count % : {cprobes}")
print(f"Negative decoding count % : {cwords}")

##### **cprobes (Negative DNA probe count %):**
Measures what percentage of the total signal comes from negative control DNA probes. A high value suggests strong background noise or non-specific binding. A low value indicates good specificity of the probes.

##### **cwords (Negative decoding count %):**
Measures the proportion of total counts attributed to negative control codewords. A high value may indicate issues with decoding accuracy or high levels of technical noise. A low value suggests a well-decoded dataset with minimal errors.


In [None]:
# Four per-cell QC distributions for column 2, one histogram per panel.
cell_obs = adata.obs
fig, axs = plt.subplots(1, 4, figsize=(15, 4))
ax_counts, ax_genes, ax_area, ax_ratio = axs

ax_counts.set_title("Total transcripts per cell")
sns.histplot(cell_obs["total_counts"], kde=False, ax=ax_counts)

ax_genes.set_title("Unique transcripts per cell")
sns.histplot(cell_obs["n_genes_by_counts"], kde=False, ax=ax_genes)

ax_area.set_title("Area of segmented cells")
sns.histplot(cell_obs["cell_area"], kde=False, ax=ax_area)

# Fraction of each cell's area taken up by its nucleus.
ax_ratio.set_title("Nucleus ratio")
nucleus_ratio = cell_obs["nucleus_area"] / cell_obs["cell_area"]
sns.histplot(nucleus_ratio, kde=False, ax=ax_ratio)

##### **1. Total Transcripts per Cell**
- Displays the distribution of total transcript counts per cell (`adata.obs["total_counts"]`).
- Helps identify cells with **very low** or **very high** total transcript counts.
- **Low transcript count cells** may indicate dying or poorly segmented cells.
- **High transcript count cells** may indicate doublets (two or more cells mistakenly segmented as one).

##### **2. Unique Transcripts per Cell**
- Shows the number of unique genes detected per cell (`adata.obs["n_genes_by_counts"]`).
- Cells with **low unique gene counts** could be low-quality or damaged.
- **Cells with abnormally high unique gene counts** could be doublets or poorly segmented regions.

##### **3. Area of Segmented Cells**
- Displays the distribution of cell sizes based on segmentation (`adata.obs["cell_area"]`).
- **Very small cells** could indicate segmentation errors or fragments.
- **Very large cells** could suggest improper merging of multiple cells.

##### **4. Nucleus Ratio**
- Shows the ratio of **nucleus area to total cell area** (`adata.obs["nucleus_area"] / adata.obs["cell_area"]`).
- A **high nucleus-to-cell ratio** could indicate small cytoplasmic areas (common in certain cell types or poorly segmented cells).
- A **low nucleus-to-cell ratio** could mean errors in nucleus segmentation or multinucleated cells.

##### **How to Use This for QC**
- **Check for outliers in each plot** (e.g., extremely high or low values).
- **Decide on filtering thresholds** based on distributions (e.g., removing cells with very low transcript counts or abnormal nucleus ratios).
- **Look for multimodal distributions**, which could indicate subpopulations or segmentation issues.

In [None]:
sc.pp.filter_cells(adata, min_counts=3)
sc.pp.filter_genes(adata, min_cells=3)

In [None]:
# Keep a copy of the current X matrix in a "counts" layer before X is normalised in place below.
# NOTE(review): this cell modifies `adata` in place, so running it twice would normalise and
# log-transform twice and overwrite the "counts" layer with already-transformed values.
adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, inplace=True)
sc.pp.log1p(adata)
# PCA -> neighbour graph -> UMAP embedding, all with scanpy defaults.
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
sc.tl.leiden(adata, resolution =0.4)

In [None]:
sc.pl.umap(
    adata,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
)

##### **UMAP colored by total_counts:**
This shows the distribution of cells based on their total transcript counts. You can observe whether any specific regions or clusters have higher or lower total counts. High total counts might indicate cells with higher overall expression or doublets, and low counts might indicate low-quality cells.

##### **UMAP colored by n_genes_by_counts:**
This visualizes the number of unique genes detected in each cell. It can help identify clusters of cells with higher or lower gene complexity. Cells with few genes detected might be poor quality or dying cells, while cells with more genes detected are likely to be of better quality.

In [None]:
# Calculate the number of cells per cluster
cluster_counts = adata.obs['leiden'].value_counts()

# Store each cell's cluster size. Mapping the categorical `leiden` column can itself return a
# categorical, which would be drawn with discrete colours; cast to int so the UMAP below uses
# a continuous gradient.
adata.obs['cluster_cell_count'] = adata.obs['leiden'].map(cluster_counts).astype(int)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    adata,
    color='cluster_cell_count',  # Use the newly created column to represent the number of cells
    wspace=0.4
)

In [None]:
sq.pl.spatial_scatter(
    adata,
    library_id="spatial",
    shape=None,
    color=[
        "leiden",
    ],
    wspace=0.4,
)

In [None]:
sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
sq.gr.centrality_scores(adata, cluster_key="leiden")
sq.pl.centrality_scores(adata, cluster_key="leiden", figsize=(16, 5))

**Clustering Coefficient:** Measures how tightly nodes are grouped or clustered together.

**Closeness Centrality:** Measures how quickly a node/group can reach others in the network.

**Degree Centrality:** Measures how well a node/group connects with non-group members in the network.

In [None]:
# Store a random 50% subsample of the cells (as a copy) as an extra "subsample" table on the
# column-2 object, then work on that table.
sdata_xenium_first_slide['column_2'].tables["subsample"] = sc.pp.subsample(adata, fraction=0.5, copy=True)
adata_subsample = sdata_xenium_first_slide['column_2'].tables["subsample"]

# Compute the co-occurrence of Leiden clusters on the subsample.
sq.gr.co_occurrence(
    adata_subsample,
    cluster_key="leiden",
)
# Plot the co-occurrence relative to cluster "12".
# NOTE(review): the cluster label "12" is hard-coded; it must exist in the current Leiden
# labels (which depend on the resolution used above) — confirm.
sq.pl.co_occurrence(
    adata_subsample,
    cluster_key="leiden",
    clusters="12",
    figsize=(10, 10),
)

In [None]:
# Check the distance unit
adata_subsample.obsm["spatial"]

The **spatial coordinates** in `adata_subsample.obsm["spatial"]` appear to be in **microns (µm)** based on their scale. These values are much larger than typical pixel coordinates (which usually range from 0 to a few thousand), suggesting that the dataset is likely using a real-world spatial scale rather than raw image pixels. This is only a heuristic: to confirm the unit, compare the coordinates with the pixel size recorded in the Xenium run metadata.

#### On a single core

In [None]:
crop0 = lambda x: sd.bounding_box_query(
    x,
    min_coordinate=[1000, 39000],
    max_coordinate=[12500, 47000],
    axes=("x", "y"),
    target_coordinate_system="global",
)

spot_3_column_2 = crop0(sdata_xenium_first_slide['column_2'])

In [None]:
adata_s3_c2 = spot_3_column_2.tables["table"]
adata_s3_c2

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

axs[0].set_title("Total transcripts per cell")
sns.histplot(
    adata_s3_c2.obs["total_counts"],
    kde=False,
    ax=axs[0],
)

axs[1].set_title("Unique transcripts per cell")
sns.histplot(
    adata_s3_c2.obs["n_genes_by_counts"],
    kde=False,
    ax=axs[1],
)


axs[2].set_title("Area of segmented cells")
sns.histplot(
    adata_s3_c2.obs["cell_area"],
    kde=False,
    ax=axs[2],
)

axs[3].set_title("Nucleus ratio")
sns.histplot(
    adata_s3_c2.obs["nucleus_area"] / adata_s3_c2.obs["cell_area"],
    kde=False,
    ax=axs[3],
)

In [None]:
sc.pl.umap(
    adata_s3_c2,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
)

In [None]:
sq.pl.spatial_scatter(
    adata_s3_c2,
    library_id="spatial",
    shape=None,
    color=[
        "leiden",
    ],
    wspace=0.4,
)

#### Zoom in on an area

In [None]:
crop1 = lambda x: sd.bounding_box_query(
    x,
    min_coordinate=[8000, 43000],
    max_coordinate=[10000, 44000],
    axes=("x", "y"),
    target_coordinate_system="global",
)

spot_3_column_2_zoomed = crop1(sdata_xenium_first_slide['column_2'])
spot_3_column_2_zoomed.pl.render_images("morphology_focus").pl.show() 

In [None]:
adata_s3_c2_zoomed = spot_3_column_2_zoomed.tables["table"]
adata_s3_c2_zoomed

In [None]:
# Four per-cell QC distributions for the zoomed region, one histogram per panel.
zoomed_obs = adata_s3_c2_zoomed.obs
fig, axs = plt.subplots(1, 4, figsize=(15, 4))
ax_counts, ax_genes, ax_area, ax_ratio = axs

ax_counts.set_title("Total transcripts per cell")
sns.histplot(zoomed_obs["total_counts"], kde=False, ax=ax_counts)

ax_genes.set_title("Unique transcripts per cell")
sns.histplot(zoomed_obs["n_genes_by_counts"], kde=False, ax=ax_genes)

ax_area.set_title("Area of segmented cells")
sns.histplot(zoomed_obs["cell_area"], kde=False, ax=ax_area)

# Nucleus-to-cell area ratio. Both areas must come from the zoomed table: dividing by the
# cell_area of the larger spot table (adata_s3_c2) mixed two different tables in one ratio.
ax_ratio.set_title("Nucleus ratio")
nucleus_ratio = zoomed_obs["nucleus_area"] / zoomed_obs["cell_area"]
sns.histplot(nucleus_ratio, kde=False, ax=ax_ratio)

In [None]:
sc.pl.umap(
    adata_s3_c2_zoomed,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
)

In [None]:
sq.pl.spatial_scatter(
    adata_s3_c2_zoomed,
    library_id="spatial",
    shape=None,
    color=[
        "leiden",
    ],
    wspace=0.4,
)

#### On each full slide

##### Merge Xenium columns

In [None]:
# Merge Xenium columns together
merged_xenium_s1 = merge_xenium_sdata.combine_xenium_columns(sdata_xenium_first_slide)
merged_xenium_s2 = merge_xenium_sdata.combine_xenium_columns(sdata_xenium_second_slide)

##### Concatenate tables

In [None]:
# Concatenate them, preserving origin as a batch key (optional)
merged_table_s1 = merge_xenium_sdata.concatenate_tables(merged_xenium_s1)
merged_table_s2 = merge_xenium_sdata.concatenate_tables(merged_xenium_s2)

In [None]:
merged_table_s1

##### Compute qc metrics

In [None]:
# With inplace=True the metrics are written into merged_table_s1.obs / .var and the call
# returns None, so the result is not assigned to a variable.
sc.pp.calculate_qc_metrics(merged_table_s1, percent_top=(10, 20, 50, 150), inplace=True)

In [None]:
# With inplace=True the metrics are written into merged_table_s2.obs / .var and the call
# returns None, so the result is not assigned to a variable.
sc.pp.calculate_qc_metrics(merged_table_s2, percent_top=(10, 20, 50, 150), inplace=True)

In [None]:
cprobes = (
    merged_table_s1.obs["control_probe_counts"].sum() / merged_table_s1.obs["total_counts"].sum() * 100
)
cwords = (
    merged_table_s1.obs["control_codeword_counts"].sum() / merged_table_s1.obs["total_counts"].sum() * 100
)
print(f"Negative DNA probe count % : {cprobes}")
print(f"Negative decoding count % : {cwords}")

In [None]:
cprobes = (
    merged_table_s2.obs["control_probe_counts"].sum() / merged_table_s2.obs["total_counts"].sum() * 100
)
cwords = (
    merged_table_s2.obs["control_codeword_counts"].sum() / merged_table_s2.obs["total_counts"].sum() * 100
)
print(f"Negative DNA probe count % slide 2: {cprobes}")
print(f"Negative decoding count % slide 2: {cwords}")

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

axs[0].set_title("Total transcripts per cell")
sns.histplot(
    merged_table_s1.obs["total_counts"],
    kde=False,
    ax=axs[0],
)

axs[1].set_title("Unique transcripts per cell")
sns.histplot(
    merged_table_s1.obs["n_genes_by_counts"],
    kde=False,
    ax=axs[1],
)


axs[2].set_title("Area of segmented cells")
sns.histplot(
    merged_table_s1.obs["cell_area"],
    kde=False,
    ax=axs[2],
)

axs[3].set_title("Nucleus ratio")
sns.histplot(
    merged_table_s1.obs["nucleus_area"] / merged_table_s1.obs["cell_area"],
    kde=False,
    ax=axs[3],
)

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

axs[0].set_title("Total transcripts per cell s2")
sns.histplot(
    merged_table_s2.obs["total_counts"],
    kde=False,
    ax=axs[0],
)

axs[1].set_title("Unique transcripts per cell s2")
sns.histplot(
    merged_table_s2.obs["n_genes_by_counts"],
    kde=False,
    ax=axs[1],
)


axs[2].set_title("Area of segmented cells s2")
sns.histplot(
    merged_table_s2.obs["cell_area"],
    kde=False,
    ax=axs[2],
)

axs[3].set_title("Nucleus ratio s2")
sns.histplot(
    merged_table_s2.obs["nucleus_area"] / merged_table_s2.obs["cell_area"],
    kde=False,
    ax=axs[3],
)

##### Filter low count cells and genes

In [None]:
sc.pp.filter_cells(merged_table_s1, min_counts=3)
sc.pp.filter_genes(merged_table_s1, min_cells=3)

In [None]:
sc.pp.filter_cells(merged_table_s2, min_counts=3)
sc.pp.filter_genes(merged_table_s2, min_cells=3)

##### Cluster cells

###### **Slide 1**

In [None]:
merged_table_s1.layers["counts"] = merged_table_s1.X.copy()
sc.pp.normalize_total(merged_table_s1, target_sum=100, inplace=True)
sc.pp.log1p(merged_table_s1)
sc.pp.pca(merged_table_s1)
sc.pp.neighbors(merged_table_s1)
sc.tl.umap(merged_table_s1)

In [None]:
sc.tl.leiden(merged_table_s1, resolution = 1)

In [None]:
sc.pl.umap(
    merged_table_s1,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
    save = '_xenium_s1_res_1.png'
)

In [None]:
# Calculate the number of cells per cluster
cluster_counts = merged_table_s1.obs['leiden'].value_counts()

# Store each cell's cluster size. Mapping the categorical `leiden` column can itself return a
# categorical, which would be drawn with discrete colours; cast to int so the UMAP below uses
# a continuous gradient.
merged_table_s1.obs['cluster_cell_count'] = merged_table_s1.obs['leiden'].map(cluster_counts).astype(int)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    merged_table_s1,
    color='cluster_cell_count',  # Use the newly created column to represent the number of cells
    wspace=0.4,
    save = '_xenium_s1_res_1_cell_count.png'
)

###### **Slide 2**

In [None]:
merged_table_s2.layers["counts"] = merged_table_s2.X.copy()
sc.pp.normalize_total(merged_table_s2, target_sum = 100, inplace=True)
sc.pp.log1p(merged_table_s2)
sc.pp.pca(merged_table_s2)
sc.pp.neighbors(merged_table_s2)
sc.tl.umap(merged_table_s2)

In [None]:
sc.tl.leiden(merged_table_s2, resolution =1)

In [None]:
sc.pl.umap(
    merged_table_s2,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
    save = '_xenium_s2_res_1.png'
)

In [None]:
# Calculate the number of cells per cluster
cluster_counts = merged_table_s2.obs['leiden'].value_counts()

# Store each cell's cluster size. Mapping the categorical `leiden` column can itself return a
# categorical, which would be drawn with discrete colours; cast to int so the UMAP below uses
# a continuous gradient.
merged_table_s2.obs['cluster_cell_count'] = merged_table_s2.obs['leiden'].map(cluster_counts).astype(int)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    merged_table_s2,
    color='cluster_cell_count',  # Use the newly created column to represent the number of cells
    wspace=0.4,
    save = '_xenium_s2_res_1_cell_count.png'
)

##### Compute centrality scores

In [None]:
sq.gr.spatial_neighbors(merged_table_s1, coord_type="generic", delaunay=True)
sq.gr.centrality_scores(merged_table_s1, cluster_key="leiden")
sq.pl.centrality_scores(merged_table_s1, cluster_key="leiden", figsize=(16, 5))

In [None]:
sq.gr.spatial_neighbors(merged_table_s2, coord_type="generic", delaunay=True)
sq.gr.centrality_scores(merged_table_s2, cluster_key="leiden")
sq.pl.centrality_scores(merged_table_s2, cluster_key="leiden", figsize=(16, 5))

##### Co-occurrence graph

In [None]:
merged_table_s1_subsample = sc.pp.subsample(merged_table_s1, fraction=0.5, copy=True)

sq.gr.co_occurrence(
    merged_table_s1_subsample,
    cluster_key="leiden",
)
sq.pl.co_occurrence(
    merged_table_s1_subsample,
    cluster_key="leiden",
    clusters="12",
    figsize=(10, 10),
)

In [None]:
merged_table_s2_subsample = sc.pp.subsample(merged_table_s2, fraction=0.5, copy=True)

sq.gr.co_occurrence(
    merged_table_s2_subsample,
    cluster_key="leiden",
)
sq.pl.co_occurrence(
    merged_table_s2_subsample,
    cluster_key="leiden",
    clusters="12",
    figsize=(10, 10),
)

#### On both slides together

In [None]:
adata_xenium_both_slides = merge_xenium_sdata.merge_xenium_slides_tables(merged_table_s1, merged_table_s2)

In [None]:
# With inplace=True the metrics are written into adata_xenium_both_slides.obs / .var and the
# call returns None, so the result is not assigned to a variable.
sc.pp.calculate_qc_metrics(adata_xenium_both_slides, percent_top=(10, 20, 50, 150), inplace=True)

In [None]:
cprobes = (
    adata_xenium_both_slides.obs["control_probe_counts"].sum() / adata_xenium_both_slides.obs["total_counts"].sum() * 100
)
cwords = (
    adata_xenium_both_slides.obs["control_codeword_counts"].sum() / adata_xenium_both_slides.obs["total_counts"].sum() * 100
)
print(f"Negative DNA probe count % : {cprobes}")
print(f"Negative decoding count % : {cwords}")

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

axs[0].set_title("Total transcripts per cell")
sns.histplot(
    adata_xenium_both_slides.obs["total_counts"],
    kde=False,
    ax=axs[0],
)

axs[1].set_title("Unique transcripts per cell")
sns.histplot(
    adata_xenium_both_slides.obs["n_genes_by_counts"],
    kde=False,
    ax=axs[1],
)


axs[2].set_title("Area of segmented cells")
sns.histplot(
    adata_xenium_both_slides.obs["cell_area"],
    kde=False,
    ax=axs[2],
)

axs[3].set_title("Nucleus ratio")
sns.histplot(
    adata_xenium_both_slides.obs["nucleus_area"] / adata_xenium_both_slides.obs["cell_area"],
    kde=False,
    ax=axs[3],
)

In [None]:
sc.pp.filter_cells(adata_xenium_both_slides, min_counts=3)
sc.pp.filter_genes(adata_xenium_both_slides, min_cells=3)

In [None]:
adata_xenium_both_slides.layers["counts"] = adata_xenium_both_slides.X.copy()
sc.pp.normalize_total(adata_xenium_both_slides, target_sum = 100, inplace=True)
sc.pp.log1p(adata_xenium_both_slides)
sc.pp.pca(adata_xenium_both_slides)
sc.pp.neighbors(adata_xenium_both_slides)
sc.tl.umap(adata_xenium_both_slides)

In [None]:
sc.tl.leiden(adata_xenium_both_slides, resolution =1)

In [None]:
sc.pl.umap(
    adata_xenium_both_slides,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
    save = '_xenium_total_res_1.png'
)

In [None]:
# Calculate the number of cells per cluster
cluster_counts = adata_xenium_both_slides.obs['leiden'].value_counts()

# Store each cell's cluster size. Mapping the categorical `leiden` column can itself return a
# categorical, which would be drawn with discrete colours; cast to int so the UMAP below uses
# a continuous gradient.
adata_xenium_both_slides.obs['cluster_cell_count'] = (
    adata_xenium_both_slides.obs['leiden'].map(cluster_counts).astype(int)
)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    adata_xenium_both_slides,
    color='cluster_cell_count',  # Use the newly created column to represent the number of cells
    wspace=0.4,
    save = '_xenium_total_res_1_cell_count.png'
)

##### Change clustering resolution to 0.3

In [None]:
sc.tl.leiden(adata_xenium_both_slides, resolution =0.3)

In [None]:
sc.pl.umap(
    adata_xenium_both_slides,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
)

In [None]:
# Calculate the number of cells per cluster
cluster_counts = adata_xenium_both_slides.obs['leiden'].value_counts()

# Store each cell's cluster size. Mapping the categorical `leiden` column can itself return a
# categorical, which would be drawn with discrete colours; cast to int so the UMAP below uses
# a continuous gradient.
adata_xenium_both_slides.obs['cluster_cell_count'] = (
    adata_xenium_both_slides.obs['leiden'].map(cluster_counts).astype(int)
)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    adata_xenium_both_slides,
    color='cluster_cell_count',  # Use the newly created column to represent the number of cells
    wspace=0.4
)

##### Change clustering resolution to 0.1

In [None]:
sc.tl.leiden(adata_xenium_both_slides, resolution =0.1)

In [None]:
sc.pl.umap(
    adata_xenium_both_slides,
    color=[
        "total_counts",
        "n_genes_by_counts",
        "leiden",
    ],
    wspace=0.4,
)

In [None]:
# Calculate the number of cells per cluster
cluster_counts = adata_xenium_both_slides.obs['leiden'].value_counts()

# Store each cell's cluster size. Mapping the categorical `leiden` column can itself return a
# categorical, which would be drawn with discrete colours; cast to int so the UMAP below uses
# a continuous gradient.
adata_xenium_both_slides.obs['cluster_cell_count'] = (
    adata_xenium_both_slides.obs['leiden'].map(cluster_counts).astype(int)
)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    adata_xenium_both_slides,
    color='cluster_cell_count',  # Use the newly created column to represent the number of cells
    wspace=0.4
)

### CODEX

#### Plot mean channel intensities for both slides

In [None]:
codex_image_s2 = sdata_codex_second_slide.images['ID_0022111_Scan1.er']['scale0'].data_vars['image']

In [None]:
codex_image_s1 = sdata_codex_first_slide.images['ID_0022110_Scan1.er']['scale0'].data_vars['image']
codex_image_s1

In [None]:
foxp3_image_s1 = codex_image_s1.sel(c='FoxP3')
foxp3_image_s1

In [None]:
foxp3_array_s1 = foxp3_image_s1.data.compute()
foxp3_array_s1

In [None]:
means_slide1 = qc.compute_codex_channel_means(codex_image_s1)

In [None]:
means_slide2 = qc.compute_codex_channel_means(codex_image_s2)

In [None]:
# Use slide 1's channel order for both slides so the bars line up.
channel_names = list(means_slide1.keys())
intensities1 = [means_slide1[name] for name in channel_names]
intensities2 = [means_slide2[name] for name in channel_names]

# Grouped bars: slide 1 to the left and slide 2 to the right of each tick.
bar_width = 0.4
x = np.arange(len(channel_names))
fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(x - bar_width / 2, intensities1, width=bar_width, label='Slide 1')
ax.bar(x + bar_width / 2, intensities2, width=bar_width, label='Slide 2')
ax.set_xticks(x)
ax.set_xticklabels(channel_names, rotation=90)
ax.set_ylabel('Mean Intensity')
ax.set_title('Channel Intensity Comparison')
ax.legend()
fig.tight_layout()
fig.savefig('/media/Lynn/pictures/Data_Exploration/codex_channel_intensities.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
from scipy.stats import pearsonr

# Pearson correlation of the per-channel mean intensities between the two slides
# (one point per channel, in the order given by channel_names).
corr, pval = pearsonr(intensities1, intensities2)
print(f"Correlation between slides: {corr:.3f} (p={pval:.2e})")

##### Normalize intensity means to DAPI

In [None]:
# Express every channel mean as a ratio to the same slide's DAPI mean.
normalized_slide1 = {
    channel: mean / means_slide1['DAPI']
    for channel, mean in means_slide1.items()
}
normalized_slide2 = {
    channel: mean / means_slide2['DAPI']
    for channel, mean in means_slide2.items()
}

In [None]:
# Channel order follows slide 1 so both slides line up bar-for-bar.
channel_names = list(means_slide1.keys())

norm1_vals = [normalized_slide1[name] for name in channel_names]
norm2_vals = [normalized_slide2[name] for name in channel_names]

# Grouped bar chart of the DAPI-normalized means.
x = np.arange(len(channel_names))
fig, ax = plt.subplots(figsize=(14, 4))
ax.bar(x - 0.2, norm1_vals, width=0.4, label='Slide 1 (Normalized)')
ax.bar(x + 0.2, norm2_vals, width=0.4, label='Slide 2 (Normalized)')
ax.set_xticks(x)
ax.set_xticklabels(channel_names, rotation=90)
ax.set_ylabel('Normalized Mean Intensity (by DAPI)')
ax.set_title('Channel Intensity Comparison (DAPI-normalized)')
ax.legend()
fig.tight_layout()
fig.savefig(
    '/media/Lynn/pictures/Data_Exploration/codex_channel_intensities_normalized.png',
    dpi=300,
    bbox_inches='tight',
)
plt.show()

normalize by the number of cores detected

plot the differences in signal instead of the boxplot

use target_sum = 100 in normalization for clustering. Also try harsher cell filtering
plot back the clusters spatially, but this time on top of the morphology image!

add metadata

In [None]:
# Symmetric relative difference per channel: |a - b| / mean(a, b), on the
# DAPI-normalized means. A channel whose mean is not positive is reported as 0
# rather than dividing by zero, matching the core- and cell-normalized
# computations later in this notebook.
diff_percent = {
    ch: (
        abs(normalized_slide1[ch] - normalized_slide2[ch])
        / ((normalized_slide1[ch] + normalized_slide2[ch]) / 2)
        if (normalized_slide1[ch] + normalized_slide2[ch]) > 0
        else 0
    )
    for ch in normalized_slide1
}

# Sort and display top differing channels
top_diffs = sorted(diff_percent.items(), key=lambda x: x[1], reverse=True)
for ch, diff in top_diffs[:5]:
    print(f"{ch}: {diff*100:.1f}% difference")

| Marker           | % Difference | Notes                                                                 |
|------------------|--------------|-----------------------------------------------------------------------|
| Pan-Cytokeratin  | 66.7%        | Could indicate epithelial content differences or staining variation.                              |
| FoxP3            | 60.2%        | Low signal marker; Tregs are sparse and can vary widely.                                          |
| CD45RO           | 44.3%        | Tissue-resident Leukocytes, Memory T cell marker; variation could be biological or technical.     |
| aSMA             | 44.0%        | Stromal marker; may reflect tissue composition or staining.                                       |
| TNFa             | 41.8%        | Inflammatory marker, Tissue staining didn't work.                                                 |

##### Normalize intensity means to number of cores

In [None]:
channel_names = list(means_slide1.keys())
intensities1 = [means_slide1[name] for name in channel_names]
intensities2 = [means_slide2[name] for name in channel_names]

# Number of cores per slide
n_cores_slide1 = 27
n_cores_slide2 = 31

# Divide each channel mean by the number of cores on its slide.
normalized_intensities1 = [value / n_cores_slide1 for value in intensities1]
normalized_intensities2 = [value / n_cores_slide2 for value in intensities2]

x = np.arange(len(channel_names))
fig, ax = plt.subplots(figsize=(12, 4))
# Legend labels are derived from the core counts above so they cannot drift apart.
ax.bar(x - 0.2, normalized_intensities1, width=0.4, label=f'Slide 1 ({n_cores_slide1} cores)')
ax.bar(x + 0.2, normalized_intensities2, width=0.4, label=f'Slide 2 ({n_cores_slide2} cores)')
ax.set_xticks(x)
ax.set_xticklabels(channel_names, rotation=90)
ax.set_ylabel('Mean Intensity per Core')
ax.set_title('Channel Intensity Comparison (normalized by slide core number)')
ax.legend()
fig.tight_layout()
fig.savefig(
    '/media/Lynn/pictures/Data_Exploration/codex_channel_intensities_normalized_by_core_nbr.png',
    dpi=300,
    bbox_inches='tight',
)
plt.show()

In [None]:
# Symmetric relative difference |a - b| / mean(a, b) between the core-normalized
# intensities of the two slides, keyed by channel name.
diff_percent = {}
for i, ch in enumerate(channel_names):
    intensity1 = normalized_intensities1[i]
    intensity2 = normalized_intensities2[i]

    # The mean of the two values is the denominator; fall back to 0 when it is not positive.
    avg_intensity = (intensity1 + intensity2) / 2
    diff_percent[ch] = abs(intensity1 - intensity2) / avg_intensity if avg_intensity > 0 else 0

# Rank channels from most to least different and report the first five.
top_diffs = sorted(diff_percent.items(), key=lambda x: x[1], reverse=True)
print("Top 5 channels with largest percentage differences:")
for ch, diff in top_diffs[:5]:
    print(f"{ch}: {diff*100:.1f}% difference")

##### Plot channel differences between slides

In [None]:
def _percent_differences(values1, values2):
    """Return |a - b| / mean(a, b) * 100 for each pair; 0.0 when the mean is not positive."""
    diffs = []
    for a, b in zip(values1, values2):
        mean_ab = (a + b) / 2
        diffs.append(abs(a - b) / mean_ab * 100 if mean_ab > 0 else 0.0)
    return diffs


# The three heatmap rows are computed here, from the per-channel means, so the
# cell runs on a fresh kernel (they were previously referenced without being defined).
channel_names = list(means_slide1.keys())
raw_means1 = [means_slide1[name] for name in channel_names]
raw_means2 = [means_slide2[name] for name in channel_names]

unnorm_diffs = _percent_differences(raw_means1, raw_means2)
dapi_diffs = _percent_differences(
    [value / means_slide1['DAPI'] for value in raw_means1],
    [value / means_slide2['DAPI'] for value in raw_means2],
)
# n_cores_slide1 / n_cores_slide2 come from the core-normalization cell above.
core_diffs = _percent_differences(
    [value / n_cores_slide1 for value in raw_means1],
    [value / n_cores_slide2 for value in raw_means2],
)

# Create heatmap: one row per normalization, one column per channel
plt.figure(figsize=(16, 4))
data_matrix = np.array([unnorm_diffs, dapi_diffs, core_diffs])
x = np.arange(len(channel_names))

im = plt.imshow(data_matrix, cmap='YlOrRd', aspect='auto')
plt.xticks(x, channel_names, rotation=90)
plt.yticks([0, 1, 2], ['Unnormalized', 'DAPI-normalized', 'Core-normalized'])
plt.title('Channel Intensity Percentage Differences Between Slides')
plt.colorbar(im, label='Percentage Difference (%)')
plt.tight_layout()
# NOTE(review): a later cell (create_comparison_heatmap call) saves to this same path
# and will overwrite this figure.
plt.savefig('/media/Lynn/pictures/Data_Exploration/codex_channel_differences_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

##### Compute TOTAL intensities (NOT MEANS) THEN NORMALIZE by cell number

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict
import logging

logger = logging.getLogger(__name__)

def compute_codex_channel_totals(image_data, 
                               downsample_factor: int = 10) -> Dict[str, float]:
    """
    Estimate the total intensity for each channel in CODEX imaging data.
    
    Parameters
    ----------
    image_data : xr.DataArray
        CODEX image data with a 'c' coordinate naming the channels.
    downsample_factor : int, default=10
        Stride used to subsample the first two axes of each channel plane.
        Higher values result in faster computation but less precise totals.
        
    Returns
    -------
    Dict[str, float]
        Dictionary mapping channel names (as ``str``) to estimated total intensity.
        
    Raises
    ------
    ValueError
        If ``image_data`` has no 'c' coordinate.
        
    Notes
    -----
    Each channel is subsampled with ``[::downsample_factor, ::downsample_factor]``,
    summed, and scaled by the ratio of full to sampled pixel counts. Using the
    actual ratio (rather than ``downsample_factor ** 2``) keeps the estimate
    correct when the image dimensions are not exact multiples of the factor;
    for exact multiples the two are identical.
    """
    totals = {}
        
    if 'c' not in image_data.coords:
        raise ValueError("Input image_data missing 'c' coordinate for channels")
    
    for c in image_data.coords['c'].values:
        # Sample pixels to avoid loading entire slide
        full_plane = image_data.sel(c=c).data
        img = full_plane[::downsample_factor, ::downsample_factor]
        
        # Pixels represented by each sampled pixel (0 for an empty plane, avoiding 0/0)
        n_sampled = img.size
        scaling_factor = full_plane.size / n_sampled if n_sampled else 0.0
        
        # Compute total and scale up for downsampling
        downsampled_total = float(img.sum().compute().item())
        scaled_total = downsampled_total * scaling_factor
        
        totals[str(c)] = scaled_total
        logger.info(f"Processed channel {c}, total intensity: {totals[str(c)]:.0f}")
        
    return totals

def plot_normalized_channel_intensities(totals_slide1: Dict[str, float],
                                      totals_slide2: Dict[str, float],
                                      n_cells_slide1: int,
                                      n_cells_slide2: int,
                                      save_path: str = None) -> None:
    """
    Draw a grouped bar chart of per-cell channel intensities for two slides.

    Each channel total is divided by its slide's cell count; the chart is shown,
    optionally saved, and a short text summary is printed.

    Parameters
    ----------
    totals_slide1, totals_slide2 : Dict[str, float]
        Channel name -> total intensity. Channel order is taken from
        ``totals_slide1``; every one of its channels must exist in ``totals_slide2``.
    n_cells_slide1, n_cells_slide2 : int
        Number of cells in each slide (used as divisors).
    save_path : str, optional
        Where to write the figure. Nothing is written when falsy.
    """
    channel_names = list(totals_slide1.keys())
    total_intensities1 = [totals_slide1[name] for name in channel_names]
    total_intensities2 = [totals_slide2[name] for name in channel_names]

    # Per-cell value = slide total / slide cell count
    mean_per_cell1 = [value / n_cells_slide1 for value in total_intensities1]
    mean_per_cell2 = [value / n_cells_slide2 for value in total_intensities2]

    positions = np.arange(len(channel_names))
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(positions - 0.2, mean_per_cell1, width=0.4,
           label=f'Slide 1 ({n_cells_slide1:,} cells)')
    ax.bar(positions + 0.2, mean_per_cell2, width=0.4,
           label=f'Slide 2 ({n_cells_slide2:,} cells)')
    ax.set_xticks(positions)
    ax.set_xticklabels(channel_names, rotation=45, ha='right')
    ax.set_ylabel('Mean Intensity per Cell')
    ax.set_title('Channel Intensity Comparison (Normalized by Cell Count)')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {save_path}")

    plt.show()

    # Text summary of the normalization inputs
    print("\nNormalization Summary:")
    print(f"Slide 1: {n_cells_slide1:,} cells")
    print(f"Slide 2: {n_cells_slide2:,} cells")
    print(f"Cell count ratio (Slide 2/Slide 1): {n_cells_slide2/n_cells_slide1:.2f}")

    # Spot-check the first three channels (zip stops at three entries)
    print("\nExample channel comparisons:")
    examples = zip(channel_names[:3], total_intensities1, total_intensities2,
                   mean_per_cell1, mean_per_cell2)
    for channel, total1, total2, per_cell1, per_cell2 in examples:
        print(f"{channel}:")
        print(f"  Total intensities: {total1:.0f} vs {total2:.0f}")
        print(f"  Per-cell means: {per_cell1:.4f} vs {per_cell2:.4f}")


# 1. Compute total intensities for both slides
totals_slide1 = compute_codex_channel_totals(codex_image_s1, downsample_factor=10)
totals_slide2 = compute_codex_channel_totals(codex_image_s2, downsample_factor=10)

# 2. Plot normalized comparison
# NOTE(review): the cell counts below are hard-coded and the same literals are repeated
# in the two following cells — confirm their source and keep them in sync.
plot_normalized_channel_intensities(
    totals_slide1, 
    totals_slide2, 
    n_cells_slide1=426443, 
    n_cells_slide2=476417,
    save_path='/media/Lynn/pictures/Data_Exploration/codex_channel_intensities_normalized_by_cell_nbr.png')

In [None]:
def analyze_channel_differences(totals_slide1: Dict[str, float],
                               totals_slide2: Dict[str, float],
                               n_cells_slide1: int,
                               n_cells_slide2: int,
                               top_n: int = 10) -> Dict[str, float]:
    """
    Report how much each channel's per-cell intensity differs between two slides.

    Totals are divided by the slide's cell count, then compared with the
    symmetric relative difference ``|a - b| / mean(a, b)`` (0 when the mean is
    not positive). A ranked table and summary statistics are printed.

    Parameters
    ----------
    totals_slide1, totals_slide2 : Dict[str, float]
        Channel name -> total intensity. Channels are taken from ``totals_slide1``
        and must all be present in ``totals_slide2``.
    n_cells_slide1, n_cells_slide2 : int
        Number of cells in each slide.
    top_n : int, default=10
        Number of top differing channels to print.

    Returns
    -------
    Dict[str, float]
        Channel name -> relative difference as a fraction (1.0 == 100%), in the
        channel order of ``totals_slide1``.
    """
    # Per-cell intensities keyed by channel (slide 1 decides which channels exist)
    per_cell1 = {ch: totals_slide1[ch] / n_cells_slide1 for ch in totals_slide1.keys()}
    per_cell2 = {ch: totals_slide2[ch] / n_cells_slide2 for ch in per_cell1}

    diff_percent = {}
    diff_absolute = {}
    for ch, value1 in per_cell1.items():
        value2 = per_cell2[ch]
        gap = abs(value1 - value2)
        midpoint = (value1 + value2) / 2
        # The midpoint is the denominator; a non-positive midpoint yields 0
        diff_percent[ch] = gap / midpoint if midpoint > 0 else 0
        diff_absolute[ch] = gap

    # Channels ranked by relative difference, largest first (ties keep input order)
    ranked = sorted(diff_percent, key=diff_percent.get, reverse=True)

    print(f"Top {min(top_n, len(ranked))} channels with largest percentage differences:")
    print("=" * 70)
    print(f"{'Channel':<15} {'% Diff':<8} {'Slide 1':<12} {'Slide 2':<12} {'Abs Diff':<10}")
    print("-" * 70)
    for ch in ranked[:top_n]:
        print(f"{ch:<15} {diff_percent[ch]*100:<7.1f}% {per_cell1[ch]:<12.4f} {per_cell2[ch]:<12.4f} {diff_absolute[ch]:<10.4f}")

    # Summary statistics over all channels
    fractions = list(diff_percent.values())
    print("\nSummary Statistics:")
    print(f"Mean percentage difference: {np.mean(fractions)*100:.1f}%")
    print(f"Median percentage difference: {np.median(fractions)*100:.1f}%")
    print(f"Max percentage difference: {max(fractions)*100:.1f}%")
    print(f"Min percentage difference: {min(fractions)*100:.1f}%")

    # Channels whose difference exceeds 50%
    significant_threshold = 0.5
    significant_channels = len([value for value in fractions if value > significant_threshold])
    print(f"Channels with >50% difference: {significant_channels}/{len(per_cell1)}")

    return diff_percent

def plot_difference_distribution(diff_percent: Dict[str, float], 
                               save_path: str = None) -> None:
    """
    Show a histogram and a box plot of the per-channel differences side by side.

    Parameters
    ----------
    diff_percent : Dict[str, float]
        Channel name -> difference value; only the values are plotted.
    save_path : str, optional
        Where to write the figure. Nothing is written when falsy.
    """
    differences = list(diff_percent.values())

    fig, (ax_hist, ax_box) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: histogram of the differences
    ax_hist.hist(differences, bins=20, alpha=0.7, edgecolor='black')
    ax_hist.set_xlabel('Percentage Difference')
    ax_hist.set_ylabel('Number of Channels')
    ax_hist.set_title('Distribution of Channel Intensity Differences')
    ax_hist.grid(axis='y', alpha=0.3)

    # Right panel: box plot of the same values
    ax_box.boxplot(differences, vert=True)
    ax_box.set_ylabel('Percentage Difference')
    ax_box.set_title('Box Plot of Percentage Differences')
    ax_box.grid(axis='y', alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Difference distribution plot saved to {save_path}")

    plt.show()

# Per-cell relative differences between slides (prints the top-10 table and summary).
# NOTE(review): cell counts are hard-coded literals repeated from the previous cell.
diff_percent = analyze_channel_differences(
    totals_slide1, 
    totals_slide2, 
    n_cells_slide1=426443, 
    n_cells_slide2=476417,
    top_n=10
)

# Histogram + box plot of those differences, saved to disk.
plot_difference_distribution(
    diff_percent,
    save_path='/media/Lynn/pictures/Data_Exploration/channel_difference_distribution.png'
)

In [None]:
def create_comparison_heatmap(totals_slide1: Dict[str, float],
                             totals_slide2: Dict[str, float],
                             n_cells_slide1: int,
                             n_cells_slide2: int,
                             save_path: str = None) -> None:
    """
    Draw a two-row heatmap of between-slide channel differences (in percent).

    Row 0 compares the raw totals; row 1 compares totals divided by each slide's
    cell count. A summary table and the five channels whose difference changes
    most under cell normalization are printed afterwards.

    Parameters
    ----------
    totals_slide1, totals_slide2 : Dict[str, float]
        Channel name -> total intensity. Channels are taken from ``totals_slide1``.
    n_cells_slide1, n_cells_slide2 : int
        Number of cells in each slide.
    save_path : str, optional
        Where to write the heatmap. Nothing is written when falsy.
    """

    def _percent_diff(a, b):
        # Symmetric percentage difference; 0 when the mean is not positive
        mean_ab = (a + b) / 2
        return abs(a - b) / mean_ab * 100 if mean_ab > 0 else 0

    channel_names = list(totals_slide1.keys())

    # Differences on the raw totals
    unnorm_diffs = [
        _percent_diff(totals_slide1[ch], totals_slide2[ch])
        for ch in channel_names
    ]

    # Differences after dividing by the cell count of each slide
    cell_diffs = [
        _percent_diff(totals_slide1[ch] / n_cells_slide1, totals_slide2[ch] / n_cells_slide2)
        for ch in channel_names
    ]

    # Heatmap: rows = normalization method, columns = channels
    plt.figure(figsize=(16, 4))
    data_matrix = np.array([unnorm_diffs, cell_diffs])
    im = plt.imshow(data_matrix, cmap='YlOrRd', aspect='auto')

    plt.xticks(np.arange(len(channel_names)), channel_names, rotation=90)
    plt.yticks([0, 1], ['Unnormalized', 'Cell-normalized'])
    plt.title('Channel Intensity Percentage Differences Between Slides')
    plt.colorbar(im, label='Percentage Difference (%)')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Heatmap saved to {save_path}")

    plt.show()

    # Summary table, one line per method
    print("\nComparison Summary:")
    print(f"{'Method':<15} {'Mean Diff':<10} {'Max Diff':<10} {'Min Diff':<10}")
    print("-" * 50)
    for method, diffs in (('Unnormalized', unnorm_diffs), ('Cell-normalized', cell_diffs)):
        print(f"{method:<15} {np.mean(diffs):<10.1f} {np.max(diffs):<10.1f} {np.min(diffs):<10.1f}")

    # Channels ordered by how much normalization changed their difference (largest first)
    normalization_effect = np.array(unnorm_diffs) - np.array(cell_diffs)
    effect_indices = np.argsort(np.abs(normalization_effect))[::-1]

    print("\nTop 5 channels most affected by cell normalization:")
    print(f"{'Channel':<15} {'Unnorm %':<10} {'Cell-norm %':<12} {'Change':<10}")
    print("-" * 50)
    for idx in effect_indices[:5]:
        print(f"{channel_names[idx]:<15} {unnorm_diffs[idx]:<10.1f} {cell_diffs[idx]:<12.1f} {normalization_effect[idx]:+.1f}")

# Unnormalized vs cell-normalized difference heatmap.
# NOTE(review): this save_path is identical to the one used by the earlier
# "Plot channel differences between slides" cell, so that file is overwritten here.
create_comparison_heatmap(
    totals_slide1,
    totals_slide2,
    n_cells_slide1=426443,
    n_cells_slide2=476417,
    save_path='/media/Lynn/pictures/Data_Exploration/codex_channel_differences_heatmap.png')

#### Arcsinh transform and plot again

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Define cofactors for each CODEX marker
# Based on the paper's guidance and typical marker expression patterns
# NOTE(review): the paper is not cited here — add the reference. The values are
# hand-picked per marker; they are the divisor in asinh(x / cofactor) below.
cofactors = {
    'DAPI': 150,           # Nuclear marker - typically bright
    'FoxP3': 50,           # Transcription factor - moderate expression
    'aSMA': 100,           # Structural protein - variable expression
    'CD4': 75,             # T cell marker - moderate expression
    'CD8': 75,             # T cell marker - moderate expression
    'CD31': 100,           # Endothelial marker - moderate to high
    'CD11c': 75,           # Dendritic cell marker - moderate
    'IFNG': 25,            # Cytokine - typically low expression
    'Pan-Cytokeratin': 150, # Epithelial marker - typically bright
    'CD68': 100,           # Macrophage marker - moderate to high
    'CD20': 100,           # B cell marker - typically bright
    'CD66b': 75,           # Neutrophil marker - moderate
    'TNFa': 25,            # Cytokine - typically low expression
    'CD45RO': 75,          # Memory T cell marker - moderate
    'CD14': 100,           # Monocyte marker - typically bright
    'CD11b': 75,           # Myeloid marker - moderate
    'Vimentin': 150,       # Structural protein - typically bright
    'CD163': 75,           # Macrophage marker - moderate
    'IL10': 25,            # Cytokine - typically low expression
    'CD45': 150,           # Pan-leukocyte marker - typically bright
    'CCR7': 50,            # Chemokine receptor - typically dim
    'CD38': 75,            # Activation marker - moderate
    'CD69': 50,            # Early activation marker - typically dim
    'Podoplanin': 75,      # Lymphatic marker - moderate
    'PNAd': 75,            # HEV marker - moderate
    'CD16': 75,            # NK/neutrophil marker - moderate
    'CXCL13': 50,          # Chemokine - typically dim to moderate
}

def asinh_transform(data, cofactor):
    """
    Inverse hyperbolic sine of ``data`` after dividing it by ``cofactor``.

    Works on scalars and NumPy arrays alike: asinh(data / cofactor).
    """
    scaled = data / cofactor
    return np.arcsinh(scaled)

def transform_mean_intensities(means_dict, cofactors_dict):
    """
    Apply the asinh transform to every marker's mean intensity, using that marker's cofactor.

    Parameters:
    means_dict: mapping of marker name -> mean intensity
    cofactors_dict: mapping of marker name -> cofactor; must contain every marker
        in means_dict (a missing marker raises KeyError)

    Returns:
    transformed_means: dictionary of marker name -> transformed mean intensity

    One line per marker is printed with the cofactor, raw mean and transformed value.
    """
    transformed_means = {}

    for name, raw_mean in means_dict.items():
        marker_cofactor = cofactors_dict[name]
        value = asinh_transform(raw_mean, marker_cofactor)
        transformed_means[name] = value
        print(f"{name}: cofactor={marker_cofactor}, raw_mean={raw_mean:.2f}, transformed={value:.3f}")

    return transformed_means

def plot_comparison(means_slide1, means_slide2, cofactors_dict, figsize=(15, 10)):
    """
    Transform both slides' mean intensities with asinh and plot them in a 2x2 figure.

    Panels: grouped bars, slide-vs-slide scatter with identity line, per-marker
    difference (slide 2 - slide 1), and a two-row heatmap.

    Returns the two transformed dictionaries and a long-format DataFrame with
    columns 'Marker', 'Transformed_Intensity' and 'Slide'.
    """
    # Transform (prints one line per marker, slide 1 first)
    transformed_slide1 = transform_mean_intensities(means_slide1, cofactors_dict)
    transformed_slide2 = transform_mean_intensities(means_slide2, cofactors_dict)

    markers = list(transformed_slide1.keys())
    values1 = list(transformed_slide1.values())
    values2 = list(transformed_slide2.values())
    n_markers = len(markers)

    # Long-format table: all slide 1 rows, then all slide 2 rows
    df = pd.DataFrame({
        'Marker': markers * 2,
        'Transformed_Intensity': values1 + values2,
        'Slide': ['Slide 1'] * n_markers + ['Slide 2'] * n_markers,
    })

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    (ax_bars, ax_scatter), (ax_diff, ax_heat) = axes

    # 1. Grouped bars
    positions = np.arange(n_markers)
    width = 0.35
    ax_bars.bar(positions - width/2, values1, width, label='Slide 1', alpha=0.8)
    ax_bars.bar(positions + width/2, values2, width, label='Slide 2', alpha=0.8)
    ax_bars.set_xlabel('Markers')
    ax_bars.set_ylabel('Transformed Mean Intensity (asinh)')
    ax_bars.set_title('Transformed Mean Intensities Comparison')
    ax_bars.set_xticks(positions)
    ax_bars.set_xticklabels(markers, rotation=45, ha='right')
    ax_bars.legend()
    ax_bars.grid(True, alpha=0.3)

    # 2. Scatter with the identity line spanning the overall value range
    ax_scatter.scatter(values1, values2, alpha=0.7)
    lowest = min(min(values1), min(values2))
    highest = max(max(values1), max(values2))
    ax_scatter.plot([lowest, highest], [lowest, highest], 'r--', alpha=0.5)
    ax_scatter.set_xlabel('Slide 1 (asinh transformed)')
    ax_scatter.set_ylabel('Slide 2 (asinh transformed)')
    ax_scatter.set_title('Slide 1 vs Slide 2 Correlation')
    ax_scatter.grid(True, alpha=0.3)

    # Label every point with its marker name
    for name, x_val, y_val in zip(markers, values1, values2):
        ax_scatter.annotate(name, (x_val, y_val),
                            xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.7)

    # 3. Signed difference per marker (looked up by name, not by position)
    differences = [transformed_slide2[name] - transformed_slide1[name] for name in markers]
    colors = ['red' if delta < 0 else 'blue' for delta in differences]
    ax_diff.bar(markers, differences, color=colors, alpha=0.7)
    ax_diff.set_xlabel('Markers')
    ax_diff.set_ylabel('Difference (Slide 2 - Slide 1)')
    ax_diff.set_title('Difference in Transformed Intensities')
    ax_diff.tick_params(axis='x', rotation=45)
    ax_diff.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    ax_diff.grid(True, alpha=0.3)

    # 4. Heatmap: one row per slide
    im = ax_heat.imshow(np.array([values1, values2]), cmap='viridis', aspect='auto')
    ax_heat.set_xticks(range(n_markers))
    ax_heat.set_xticklabels(markers, rotation=45, ha='right')
    ax_heat.set_yticks([0, 1])
    ax_heat.set_yticklabels(['Slide 1', 'Slide 2'])
    ax_heat.set_title('Transformed Intensities Heatmap')
    plt.colorbar(im, ax=ax_heat, label='asinh(intensity)')

    plt.tight_layout()
    plt.show()

    return transformed_slide1, transformed_slide2, df

# Transform the per-channel means of both slides and draw the 2x2 comparison figure.
transformed_slide1, transformed_slide2, comparison_df = plot_comparison(means_slide1, means_slide2, cofactors)

# Echo the cofactor table and usage notes.
# NOTE(review): the usage text below describes the call that was already executed above.
print("CODEX Mean Intensity Transformation Code Loaded!")
print("\nCofactors defined for markers:")
for marker, cofactor in cofactors.items():
    print(f"  {marker}: {cofactor}")

print(f"\nTo use this code:")
print("1. transformed_slide1, transformed_slide2, df = plot_comparison(means_slide1, means_slide2, cofactors)")
print("2. The function will transform and plot comparisons between your slides")
print("3. You can adjust cofactors in the cofactors dictionary if needed")

In [None]:
# How to use the optimize_cofactors_interactive function with your slide data

# Step 1: Get the plotting function
# NOTE(review): optimize_cofactors_interactive is not defined or imported in the part of
# the notebook visible here — confirm an earlier cell defines it, otherwise this line
# raises NameError on a fresh kernel. plot_transformation_effects is not used again in this cell.
plot_transformation_effects = optimize_cofactors_interactive(raw_data_dict=None)

# Step 2: Your data structure
# totals_slide1: Dict[str, float] - contains marker names and their total intensities
# totals_slide2: Dict[str, float] - contains marker names and their total intensities

# Since you have totals (single values per marker per slide), we'll combine them for analysis
def combine_slide_data(totals_slide1, totals_slide2):
    """Pair up both slides' values per marker.

    Returns a dict mapping each marker present in BOTH inputs to a 2-element
    array ``[slide 1 value, slide 2 value]``; markers missing from slide 2 are skipped.
    """
    return {
        marker: np.array([value1, totals_slide2[marker]])
        for marker, value1 in totals_slide1.items()
        if marker in totals_slide2
    }

# Combine your slide data
# (marker -> array([slide 1 total, slide 2 total]); not referenced again in this cell)
combined_intensities = combine_slide_data(totals_slide1, totals_slide2)

# Step 3a: Optimize cofactors for individual markers
# Since you only have 2 data points per marker, the histograms won't be very informative
# Instead, let's create a more useful approach for your data

def optimize_cofactors_for_slide_data(totals_slide1, totals_slide2, current_cofactors):
    """
    Print, for each marker, how a fixed set of cofactors changes the asinh-transformed values.

    For every marker present in both slides a table is printed with one row per
    candidate cofactor: both transformed values, their difference and their ratio.
    The row matching the marker's entry in ``current_cofactors`` (default 50) is
    flagged. Nothing is returned and no cofactor is selected automatically.
    """
    candidate_cofactors = [25, 50, 75, 100, 150, 200, 250]

    for marker, slide1_val in totals_slide1.items():
        if marker not in totals_slide2:
            continue
        slide2_val = totals_slide2[marker]
        active_cofactor = current_cofactors.get(marker, 50)

        print(f"\n=== {marker} ===")
        print(f"Raw values - Slide 1: {slide1_val:.1f}, Slide 2: {slide2_val:.1f}")
        print(f"Raw difference: {slide2_val - slide1_val:.1f}")

        print("Cofactor | Slide1_transformed | Slide2_transformed | Difference | Fold_change")
        print("-" * 80)

        for cofactor in candidate_cofactors:
            trans1 = asinh_transform(slide1_val, cofactor)
            trans2 = asinh_transform(slide2_val, cofactor)
            # The ratio is undefined for a non-positive denominator; report infinity instead
            if trans1 > 0:
                fold_change = trans2 / trans1
            else:
                fold_change = float('inf')

            current_marker = "👈 CURRENT" if cofactor == active_cofactor else ""
            print(f"{cofactor:8d} | {trans1:17.3f} | {trans2:17.3f} | {trans2 - trans1:9.3f} | {fold_change:10.3f} {current_marker}")

# Use the optimization function
# (prints one cofactor table per marker; it is applied to the slide totals, not the means)
optimize_cofactors_for_slide_data(totals_slide1, totals_slide2, cofactors)

# Step 3b: Visual comparison of cofactor effects (optional)
def plot_cofactor_effects_for_slides(totals_slide1, totals_slide2, marker_name, cofactor_range=None):
    """
    Plot one marker's asinh-transformed values for both slides across several cofactors.

    The top panel shows the transformed value per slide, the bottom panel their
    difference (slide 2 - slide 1). ``cofactor_range`` defaults to
    [25, 50, 75, 100, 150, 200].

    Returns (slide 1 values, slide 2 values, differences) as lists, or None
    (after printing a message) when the marker is missing from either slide.
    """
    if cofactor_range is None:
        cofactor_range = [25, 50, 75, 100, 150, 200]

    if not (marker_name in totals_slide1 and marker_name in totals_slide2):
        print(f"Marker {marker_name} not found in both slides")
        return

    raw1 = totals_slide1[marker_name]
    raw2 = totals_slide2[marker_name]

    # One transformed value per cofactor and slide
    curve1 = [asinh_transform(raw1, cf) for cf in cofactor_range]
    curve2 = [asinh_transform(raw2, cf) for cf in cofactor_range]
    gaps = [second - first for first, second in zip(curve1, curve2)]

    fig, (ax_values, ax_gap) = plt.subplots(2, 1, figsize=(12, 8))

    # Top: transformed values
    ax_values.plot(cofactor_range, curve1, 'o-', label='Slide 1', linewidth=2, markersize=8)
    ax_values.plot(cofactor_range, curve2, 's-', label='Slide 2', linewidth=2, markersize=8)
    ax_values.set_xlabel('Cofactor')
    ax_values.set_ylabel('Transformed Intensity')
    ax_values.set_title(f'{marker_name} - Transformed Values vs Cofactor')
    ax_values.legend()
    ax_values.grid(True, alpha=0.3)

    # Bottom: difference between the slides
    ax_gap.plot(cofactor_range, gaps, 'ro-', linewidth=2, markersize=8)
    ax_gap.set_xlabel('Cofactor')
    ax_gap.set_ylabel('Difference (Slide 2 - Slide 1)')
    ax_gap.set_title(f'{marker_name} - Difference vs Cofactor')
    ax_gap.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax_gap.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return curve1, curve2, gaps

# Example usage for specific markers:
# (the first call uses the default cofactor list; the second passes a lower range for IFNG)
plot_cofactor_effects_for_slides(totals_slide1, totals_slide2, 'CD4')
plot_cofactor_effects_for_slides(totals_slide1, totals_slide2, 'IFNG', cofactor_range=[10, 25, 50, 75, 100])

# Step 4: Complete analysis for all markers
def analyze_all_markers(totals_slide1, totals_slide2, current_cofactors):
    """
    Tabulate raw and asinh-transformed values for every marker shared by both slides.

    Each marker uses its entry in ``current_cofactors`` (50 when absent). One
    table row is printed per marker.

    Returns a dict: marker -> {'raw_slide1', 'raw_slide2', 'transformed_slide1',
    'transformed_slide2', 'difference', 'cofactor'}.
    """
    print("=== CURRENT COFACTOR ANALYSIS ===")
    print(f"{'Marker':<15} | {'Raw_S1':<8} | {'Raw_S2':<8} | {'Trans_S1':<9} | {'Trans_S2':<9} | {'Difference':<10}")
    print("-" * 80)

    results = {}
    for marker, raw1 in totals_slide1.items():
        if marker not in totals_slide2:
            continue  # only markers measured on both slides are compared

        raw2 = totals_slide2[marker]
        cofactor = current_cofactors.get(marker, 50)
        trans1 = asinh_transform(raw1, cofactor)
        trans2 = asinh_transform(raw2, cofactor)
        diff = trans2 - trans1

        results[marker] = dict(
            raw_slide1=raw1,
            raw_slide2=raw2,
            transformed_slide1=trans1,
            transformed_slide2=trans2,
            difference=diff,
            cofactor=cofactor,
        )

        print(f"{marker:<15} | {raw1:<8.1f} | {raw2:<8.1f} | {trans1:<9.3f} | {trans2:<9.3f} | {diff:<10.3f}")

    return results

# Run complete analysis
# (marker -> dict of raw/transformed values for both slides, using the cofactors defined above)
analysis_results = analyze_all_markers(totals_slide1, totals_slide2, cofactors)

# How to interpret the results:
# NOTE(review): the guidance printed below is generic text about choosing cofactors;
# it is not derived from analysis_results.
print("""
How to interpret the cofactor optimization plots:

1. Look for a cofactor that spreads your data well across the transformed range
2. Avoid cofactors that compress all your data into a narrow range
3. Avoid cofactors that don't transform the data enough (looks too similar to raw)
4. The ideal cofactor should:
   - Separate positive and negative populations clearly
   - Not over-compress high values
   - Maintain biological signal differences

5. General guidelines:
   - Bright markers (DAPI, Pan-CK): cofactors 100-200
   - Moderate markers (CD markers): cofactors 50-100  
   - Dim markers (cytokines): cofactors 10-50
""")

# Alternative: If you want to test a single marker with many cofactors
def test_single_marker_cofactors(marker_name, intensities, cofactor_list):
    """Plot histograms of one marker's asinh-transformed intensities for several cofactors.

    Parameters
    ----------
    marker_name : str
        Marker label, used in the figure title.
    intensities : array-like
        Raw intensities; passed unchanged to ``asinh_transform``.
    cofactor_list : sequence
        Cofactors to try. One histogram panel is drawn per entry, laid out on
        a grid with two rows.

    Raises
    ------
    ValueError
        If ``cofactor_list`` is empty (there is nothing to plot).
    """
    n_cofactors = len(cofactor_list)
    if n_cofactors == 0:
        raise ValueError("cofactor_list must contain at least one cofactor")

    # Two rows, ceil(n / 2) columns.
    n_cols = n_cofactors // 2 + n_cofactors % 2

    # squeeze=False always returns a 2-D array of Axes, so flattening is valid
    # for every grid size -- including the 2x1 grid produced by one or two
    # cofactors, which previously was not flattened and broke axes[i].hist.
    fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cofactors, 8), squeeze=False)
    axes = axes.flatten()

    for ax, cofactor in zip(axes, cofactor_list):
        transformed = asinh_transform(intensities, cofactor)
        ax.hist(transformed, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
        ax.set_title(f'Cofactor: {cofactor}')
        ax.set_xlabel('asinh(intensity)')
        ax.set_ylabel('Frequency')
        ax.grid(True, alpha=0.3)

    # Hide the spare panel left over when the number of cofactors is odd.
    for ax in axes[n_cofactors:]:
        ax.set_visible(False)

    plt.suptitle(f'{marker_name} - Cofactor Optimization', fontsize=16)
    plt.tight_layout()
    plt.show()

# Example usage of the alternative function:
#test_cofactor_list = [25, 50, 75, 100, 125, 150, 175, 200]
#test_single_marker_cofactors('CD4', cd4_intensities, test_cofactor_list)

### QC Tables

#### Xenium

##### Before filtering

In [None]:
# Per-slide QC summaries computed from the merged tables as they are at this
# point in the notebook (the filtering cells come later).
slide1_summary = qc.compute_xenium_metrics(merged_table_s1, merged_xenium_s1)
slide2_summary = qc.compute_xenium_metrics(merged_table_s2, merged_xenium_s2)

# One row per slide, one column per metric returned by compute_xenium_metrics.
df = pd.DataFrame([slide1_summary, slide2_summary], index=["Xenium Slide 1", "Xenium Slide 2"])
print(df)

| Metric                                                | Xenium Slide 1   | Xenium Slide 2   |
|-------------------------------------------------------|------------------|------------------|
| Total number of cells                                 | 457,439          | 565,266          |
| Total number of transcripts                           | 31,327,513       | 35,161,247       |
| Average genes per cell                                | 26.68            | 22.79            |
| % of transcripts **inside** cells                     | 88.57%           | 85.10%           |
| % of transcripts **outside** cells                    | 11.43%           | 14.90%           |
| Total number of spots                                 | 32               | 32               |
| Number of spots detected                              | 27               | 31               |
| Number of different patients                          | 8                | 6                |
| # Colon tissue samples at Timepoint 1                 | 6                | 6                |
| # Colon tissue samples at Timepoint 2                 | 0                | 0                |
| # Ileum tissue samples at Timepoint 1                 | 4                | 18               |
| # Ileum tissue samples at Timepoint 2                 | 17               | 7                |
 
*Only detected spots are counted in the rows whose label starts with `#`.*

*CODEX and Xenium detect the same spots*                                

Timepoint 1 -> Before Treatment

Timepoint 2 -> During Treatment

##### After filtering

In [None]:
# Slide 1: drop cells with fewer than 3 counts, then genes seen in fewer than 3 cells.
# Both calls modify merged_table_s1 in place, so the unfiltered table is no longer
# available after this cell has run.
sc.pp.filter_cells(merged_table_s1, min_counts=3)
sc.pp.filter_genes(merged_table_s1, min_cells=3)

In [None]:
# Slide 2: drop cells with fewer than 3 counts, then genes seen in fewer than 3 cells.
# Both calls modify merged_table_s2 in place.
sc.pp.filter_cells(merged_table_s2, min_counts=3)
sc.pp.filter_genes(merged_table_s2, min_cells=3)

In [None]:
# Combined table: drop cells with fewer than 3 counts, then genes seen in fewer
# than 3 cells. Both calls modify adata_xenium_both_slides in place.
sc.pp.filter_cells(adata_xenium_both_slides, min_counts=3)
sc.pp.filter_genes(adata_xenium_both_slides, min_cells=3)

In [None]:
# Add scanpy's QC metric columns to each per-slide table (written in place).
for slide_table in (merged_table_s1, merged_table_s2):
    sc.pp.calculate_qc_metrics(slide_table, percent_top=(10, 20, 50, 150), inplace=True)

In [None]:
# Add scanpy's QC metric columns to the combined table (written in place).
sc.pp.calculate_qc_metrics(adata_xenium_both_slides, percent_top=(10, 20, 50, 150), inplace=True)

In [None]:
# Recompute the per-slide QC summaries; slide1_summary, slide2_summary and df
# are rebound here and replace any values computed earlier in the notebook.
slide1_summary = qc.compute_xenium_metrics(merged_table_s1, merged_xenium_s1)
slide2_summary = qc.compute_xenium_metrics(merged_table_s2, merged_xenium_s2)

# One row per slide, one column per metric returned by compute_xenium_metrics.
df = pd.DataFrame([slide1_summary, slide2_summary], index=["Xenium Slide 1", "Xenium Slide 2"])
print(df)

In [None]:
# Summary for the combined (both slides) table.
# NOTE(review): the combined table is paired with merged_xenium_s1 (slide 1) only --
# confirm compute_xenium_metrics does not need slide-2 data for the totals.
total_summary = qc.compute_xenium_metrics(adata_xenium_both_slides, merged_xenium_s1)

# Single-row frame labelled "Total"; presumably total_summary holds one scalar
# per metric -- TODO confirm.
df = pd.DataFrame(total_summary, index=["Total"])
print(df)

In [None]:
# Share of all counted transcripts that come from negative-control probes / codewords.
obs_both = adata_xenium_both_slides.obs
total_counts_both = obs_both["total_counts"].sum()

cprobes = obs_both["control_probe_counts"].sum() / total_counts_both * 100
cwords = obs_both["control_codeword_counts"].sum() / total_counts_both * 100

print(f"Negative DNA probe count % total: {cprobes}")
print(f"Negative decoding count % total: {cwords}")

In [None]:
# Four QC distributions for slide 1, one panel each.
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

# Panel 0: total transcript counts per cell
axs[0].set_title("Total transcripts per cell")
sns.histplot(merged_table_s1.obs["total_counts"], kde=False, ax=axs[0])

# Panel 1: number of distinct genes detected per cell
axs[1].set_title("Unique transcripts per cell")
sns.histplot(merged_table_s1.obs["n_genes_by_counts"], kde=False, ax=axs[1])

# Panel 2: segmented cell area
axs[2].set_title("Area of segmented cells")
sns.histplot(merged_table_s1.obs["cell_area"], kde=False, ax=axs[2])

# Panel 3: nucleus area as a fraction of cell area
axs[3].set_title("Nucleus ratio")
sns.histplot(merged_table_s1.obs["nucleus_area"] / merged_table_s1.obs["cell_area"], kde=False, ax=axs[3])

In [None]:
# Mean of the per-cell "cell_area" column for slide 1 (displayed as the cell output).
merged_table_s1.obs["cell_area"].mean()

In [None]:
# Four QC distributions for slide 2, one panel each.
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

# Panel 0: total transcript counts per cell
axs[0].set_title("Total transcripts per cell s2")
sns.histplot(merged_table_s2.obs["total_counts"], kde=False, ax=axs[0])

# Panel 1: number of distinct genes detected per cell
axs[1].set_title("Unique transcripts per cell s2")
sns.histplot(merged_table_s2.obs["n_genes_by_counts"], kde=False, ax=axs[1])

# Panel 2: segmented cell area
axs[2].set_title("Area of segmented cells s2")
sns.histplot(merged_table_s2.obs["cell_area"], kde=False, ax=axs[2])

# Panel 3: nucleus area as a fraction of cell area
axs[3].set_title("Nucleus ratio s2")
sns.histplot(merged_table_s2.obs["nucleus_area"] / merged_table_s2.obs["cell_area"], kde=False, ax=axs[3])

In [None]:
# Mean of the per-cell "cell_area" column for slide 2 (displayed as the cell output).
merged_table_s2.obs["cell_area"].mean()

In [None]:
# Four QC distributions for the combined (both slides) table, one panel each.
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

# Panel 0: total transcript counts per cell
axs[0].set_title("Total transcripts per cell")
sns.histplot(adata_xenium_both_slides.obs["total_counts"], kde=False, ax=axs[0])

# Panel 1: number of distinct genes detected per cell
axs[1].set_title("Unique transcripts per cell")
sns.histplot(adata_xenium_both_slides.obs["n_genes_by_counts"], kde=False, ax=axs[1])

# Panel 2: segmented cell area
axs[2].set_title("Area of segmented cells")
sns.histplot(adata_xenium_both_slides.obs["cell_area"], kde=False, ax=axs[2])

# Panel 3: nucleus area as a fraction of cell area
axs[3].set_title("Nucleus ratio")
sns.histplot(adata_xenium_both_slides.obs["nucleus_area"] / adata_xenium_both_slides.obs["cell_area"], kde=False, ax=axs[3])

In [None]:
# Mean of the per-cell "cell_area" column over both slides (displayed as the cell output).
adata_xenium_both_slides.obs["cell_area"].mean()

#### Xenium with Metadata

In [None]:
# Load the combined table (with metadata) from disk. This REBINDS
# adata_xenium_both_slides, replacing the in-memory object used by the cells above.
# NOTE(review): absolute machine-specific path -- consider a configurable data directory.
adata_xenium_both_slides = ad.read_h5ad("/media/Lynn/data/Xenium_table_with_metadata/adata_both_slides.h5ad")

In [None]:
adata_xenium_both_slides

In [None]:
# With inplace=True scanpy writes the metrics into the AnnData's .obs/.var and
# returns None, so qc_metrics is None after this line (see scanpy docs).
qc_metrics= sc.pp.calculate_qc_metrics(adata_xenium_both_slides, percent_top=(10, 20, 50, 150), inplace=True)

In [None]:
# Share of all counted transcripts that come from negative-control probes / codewords.
obs_both = adata_xenium_both_slides.obs
total_counts_both = obs_both["total_counts"].sum()

cprobes = obs_both["control_probe_counts"].sum() / total_counts_both * 100
cwords = obs_both["control_codeword_counts"].sum() / total_counts_both * 100

print(f"Negative DNA probe count % : {cprobes}")
print(f"Negative decoding count % : {cwords}")

In [None]:
# Four QC distributions for the combined table loaded with metadata, one panel each.
fig, axs = plt.subplots(1, 4, figsize=(15, 4))

# Panel 0: total transcript counts per cell
axs[0].set_title("Total transcripts per cell")
sns.histplot(adata_xenium_both_slides.obs["total_counts"], kde=False, ax=axs[0])

# Panel 1: number of distinct genes detected per cell
axs[1].set_title("Unique transcripts per cell")
sns.histplot(adata_xenium_both_slides.obs["n_genes_by_counts"], kde=False, ax=axs[1])

# Panel 2: segmented cell area
axs[2].set_title("Area of segmented cells")
sns.histplot(adata_xenium_both_slides.obs["cell_area"], kde=False, ax=axs[2])

# Panel 3: nucleus area as a fraction of cell area
axs[3].set_title("Nucleus ratio")
sns.histplot(adata_xenium_both_slides.obs["nucleus_area"] / adata_xenium_both_slides.obs["cell_area"], kde=False, ax=axs[3])

In [None]:
# Drop cells with fewer than 5 counts, then genes seen in fewer than 5 cells.
# Both calls modify adata_xenium_both_slides in place.
sc.pp.filter_cells(adata_xenium_both_slides, min_counts=5)
sc.pp.filter_genes(adata_xenium_both_slides, min_cells=5)

In [None]:
# Keep only cells whose core_ID is present (non-null) and is not the "Unknown" label.
core_id_per_cell = adata_xenium_both_slides.obs["core_ID"]
has_known_core = core_id_per_cell.notnull() & (core_id_per_cell != "Unknown")
adata_filtered = adata_xenium_both_slides[has_known_core].copy()

# Report how many cells/genes remain and which cores are represented.
print(adata_filtered.shape)
print(adata_filtered.obs["core_ID"].unique())

In [None]:
# Keep a copy of X in the "counts" layer before X is normalised in place.
adata_filtered.layers["counts"] = adata_filtered.X.copy()
# Scale every cell to a total of 100, then log1p-transform.
sc.pp.normalize_total(adata_filtered, target_sum = 100, inplace=True)
sc.pp.log1p(adata_filtered)
# PCA -> neighbour graph -> UMAP embedding, all with scanpy's default parameters.
sc.pp.pca(adata_filtered)
sc.pp.neighbors(adata_filtered)
sc.tl.umap(adata_filtered)

In [None]:
# Leiden clustering with scanpy's default parameters; the "leiden" obs column
# is what the plots below colour by.
sc.tl.leiden(adata_filtered)

In [None]:
# Persist the filtered, clustered table to disk.
# NOTE(review): absolute machine-specific path -- consider a configurable data directory.
adata_filtered.write("/media/Lynn/data/Xenium_table_with_metadata/adata_both_slides_filtered_and_leiden.h5ad")

In [None]:
# UMAP coloured by two per-cell QC metrics and by Leiden cluster; the figure is also saved.
sc.pl.umap(
    adata_filtered,
    color=["total_counts", "n_genes_by_counts", "leiden"],
    wspace=0.4,
    save='_xenium_both_slides_with_metadata_res_1.png',
)

In [None]:
# Number of cells in each Leiden cluster
cluster_counts = adata_filtered.obs['leiden'].value_counts()

# Per-cell size of the cluster the cell belongs to.
# 'leiden' is presumably a categorical column (scanpy's leiden output), and mapping
# a categorical can itself return a categorical. Cast to int so scanpy draws a
# continuous colour gradient instead of one discrete colour per distinct cluster size.
adata_filtered.obs['cluster_cell_count'] = (
    adata_filtered.obs['leiden'].map(cluster_counts).astype(int)
)

# Plot UMAP with a gradient based on the number of cells per cluster
sc.pl.umap(
    adata_filtered,
    color='cluster_cell_count',  # numeric column -> continuous colour bar
    wspace=0.4,
    save = '_xenium_both_slides_with_metadata_res_1_cell_count.png'
)

In [None]:
# Split the filtered table into one independent copy per slide.
# The slide_ID column is compared against integers here.
slide_id_per_cell = adata_filtered.obs['slide_ID']
adata_slide_1 = adata_filtered[slide_id_per_cell == 22110].copy()
adata_slide_2 = adata_filtered[slide_id_per_cell == 22111].copy()

In [None]:
# Spatial map of slide 22110, one point per cell (shape=None), coloured by Leiden cluster.
sq.pl.spatial_scatter(
    adata_slide_1,
    library_id="spatial",
    shape=None,
    color=["leiden"],
    wspace=0.4,
    save='_spatial_xenium_22110_with_metadata_res_1.png',
)

In [None]:
# Spatial map of slide 22111, one point per cell (shape=None), coloured by Leiden cluster.
sq.pl.spatial_scatter(
    adata_slide_2,
    library_id="spatial",
    shape=None,
    color=["leiden"],
    wspace=0.4,
    save='_spatial_xenium_22111_with_metadata_res_1.png',
)

In [None]:
# Cells of core X3Y6 on slide 22111 (independent copy); displayed as the cell output.
is_core_x3y6 = adata_slide_2.obs['core_ID'] == 'X3Y6'
adata_s3_c2 = adata_slide_2[is_core_x3y6].copy()
adata_s3_c2

In [None]:
# Spatial map of the X3Y6 core subset, coloured by Leiden cluster.
sq.pl.spatial_scatter(
    adata_s3_c2,
    library_id="spatial",
    shape=None,
    color=["leiden"],
    wspace=0.4,
    save='_spatial__xenium_22111_X3Y6_with_metadata_res_1.png',
)

In [None]:
# Cells of core X4Y2 on slide 22111 (independent copy) ...
is_core_x4y2 = adata_slide_2.obs['core_ID'] == 'X4Y2'
adata_s2_c4 = adata_slide_2[is_core_x4y2].copy()

# ... and their spatial map, coloured by Leiden cluster.
sq.pl.spatial_scatter(
    adata_s2_c4,
    library_id="spatial",
    shape=None,
    color=["leiden"],
    wspace=0.4,
    save='_spatial__xenium_22111_X4Y2_with_metadata_res_1.png',
)

In [None]:
# Cells of core X1Y1 on slide 22110 (independent copy) ...
is_core_x1y1 = adata_slide_1.obs['core_ID'] == 'X1Y1'
adata_s1_c1 = adata_slide_1[is_core_x1y1].copy()

# ... and their spatial map, coloured by Leiden cluster.
sq.pl.spatial_scatter(
    adata_s1_c1,
    library_id="spatial",
    shape=None,
    color=["leiden"],
    wspace=0.4,
    save='_spatial__xenium_22110_X1Y1_with_metadata_res_1.png',
)

In [None]:
# UMAP coloured by the "tissue" obs column; the figure is also saved.
sc.pl.umap(
    adata_filtered,
    color=["tissue"],
    wspace=0.4,
    save='_by_tissue_xenium_with_metadata_res_1.png',
)

In [None]:
# UMAP coloured by the "time_point" obs column; the figure is also saved.
sc.pl.umap(
    adata_filtered,
    color=["time_point"],
    wspace=0.4,
    save='_by_timepoint_xenium_with_metadata_res_1.png',
)

In [None]:
# Store patient IDs as strings (in place) -- presumably so the UMAP treats them
# as discrete labels rather than a numeric scale.
patient_ids_as_text = adata_filtered.obs["patient_ID"].astype(str)
adata_filtered.obs["patient_ID"] = patient_ids_as_text

sc.pl.umap(
    adata_filtered,
    color=["patient_ID"],
    wspace=0.4,
    save='_by_patient_xenium_with_metadata_res_1.png',
)

In [None]:
# UMAP coloured by the "batch" obs column; the figure is also saved.
sc.pl.umap(
    adata_filtered,
    color=["batch"],
    wspace=0.4,
    save='_by_batch_xenium_with_metadata_res_1.png',
)

In [None]:
# Store the year as strings (in place) -- presumably so the UMAP treats it as a
# discrete label rather than a numeric scale.
year_as_text = adata_filtered.obs["year"].astype(str)
adata_filtered.obs["year"] = year_as_text

sc.pl.umap(
    adata_filtered,
    color=["year"],
    wspace=0.4,
    save='_by_year_xenium_with_metadata_res_1.png',
)

In [None]:
# Split the filtered table into independent ileum and colon copies.
tissue_per_cell = adata_filtered.obs['tissue']
adata_ileum = adata_filtered[tissue_per_cell == 'ileum'].copy()
adata_colon = adata_filtered[tissue_per_cell == 'colon'].copy()

In [None]:
# UMAP of ileum cells only, coloured by core; the figure is also saved.
sc.pl.umap(
    adata_ileum,
    color=["core_ID"],
    wspace=0.4,
    save='_ileum_by_core_xenium_with_metadata_res_1.png',
)

In [None]:
# UMAP of colon cells only, coloured by core; the figure is also saved.
sc.pl.umap(
    adata_colon,
    color=["core_ID"],
    wspace=0.4,
    save='_colon_by_core_xenium_with_metadata_res_1.png',
)

In [None]:
# Sum of the per-cell "transcript_counts" column (transcripts assigned to the kept cells)
assigned = adata_filtered.obs["transcript_counts"].sum()

# Sum of the per-cell "unassigned_codeword_counts" column.
# NOTE(review): adata_filtered.obs has one row per cell, so this is a per-cell
# quantity summed over cells -- it cannot count transcripts that lie outside
# cells. Confirm this column is the intended numerator; an earlier cell of this
# notebook counts transcripts whose cell_id is "UNASSIGNED" for that purpose.
unassigned = adata_filtered.obs["unassigned_codeword_counts"].sum()

# Denominator used for the percentage below
total = assigned + unassigned

# Share of `unassigned` in `total`, in percent (no guard for total == 0)
pct_outside = (unassigned / total) * 100

print(f"Percentage of transcripts outside cells: {pct_outside:.2f}%")

In [None]:
def qc_adata(adata):
    """Summarise basic per-cell QC numbers of an AnnData-like object.

    Reads ``adata.n_obs`` and the ``n_genes_by_counts`` and
    ``transcript_counts`` columns of ``adata.obs``.

    Returns a dict with these keys, in this order:
    ``total_cells`` (number of cells), ``average_unique_genes_per_cell``
    (mean of ``n_genes_by_counts``), ``mean_transcripts_per_cell`` (mean of
    ``transcript_counts``) and ``total_transcripts_in_cells`` (sum of
    ``transcript_counts``).
    """
    return {
        'total_cells': adata.n_obs,
        'average_unique_genes_per_cell': adata.obs['n_genes_by_counts'].mean(),
        'mean_transcripts_per_cell': adata.obs["transcript_counts"].mean(),
        'total_transcripts_in_cells': adata.obs['transcript_counts'].sum(),
    }

In [None]:
# QC summary dicts for slide 1, slide 2 and the full filtered table, in that order.
qc_slide_1, qc_slide_2, qc_total = (
    qc_adata(table) for table in (adata_slide_1, adata_slide_2, adata_filtered)
)

In [None]:
qc_slide_1

In [None]:
qc_slide_2

In [None]:
qc_total

In [None]:
# Work on a 10% random subsample (returned as a new object; adata_filtered is untouched).
adata_subsample = sc.pp.subsample(adata_filtered, fraction=0.1, copy=True)

# Co-occurrence of Leiden clusters, computed on the subsample.
sq.gr.co_occurrence(adata_subsample, cluster_key="leiden")

In [None]:
# Plot the co-occurrence result for Leiden cluster '12' only; the figure is also saved.
sq.pl.co_occurrence(
    adata_subsample,
    cluster_key="leiden",
    figsize=(10, 10),
    clusters='12',
    save='_cooccurence_xenium_with_metadata_res_1_subset_0.1.png',
)

In [None]:
# Slide 1: share of all counted transcripts that come from negative-control probes / codewords.
obs_slide_1 = adata_slide_1.obs
total_counts_slide_1 = obs_slide_1["total_counts"].sum()

cprobes = obs_slide_1["control_probe_counts"].sum() / total_counts_slide_1 * 100
cwords = obs_slide_1["control_codeword_counts"].sum() / total_counts_slide_1 * 100

print(f"Negative DNA probe count % slide 1: {cprobes}")
print(f"Negative decoding count % slide 1: {cwords}")

In [None]:
# Slide 2: share of all counted transcripts that come from negative-control probes / codewords.
obs_slide_2 = adata_slide_2.obs
total_counts_slide_2 = obs_slide_2["total_counts"].sum()

cprobes = obs_slide_2["control_probe_counts"].sum() / total_counts_slide_2 * 100
cwords = obs_slide_2["control_codeword_counts"].sum() / total_counts_slide_2 * 100

print(f"Negative DNA probe count % slide 2: {cprobes}")
print(f"Negative decoding count % slide 2: {cwords}")

In [None]:
# Filtered table (both slides): share of all counted transcripts that come from
# negative-control probes / codewords.
obs_filtered = adata_filtered.obs
total_counts_filtered = obs_filtered["total_counts"].sum()

cprobes = obs_filtered["control_probe_counts"].sum() / total_counts_filtered * 100
cwords = obs_filtered["control_codeword_counts"].sum() / total_counts_filtered * 100

print(f"Negative DNA probe count % total: {cprobes}")
print(f"Negative decoding count % total: {cwords}")

#### Annotation

In [None]:
# pandas is already imported as pd in the notebook's first cell; this repeat is redundant.
import pandas as pd

# Read the marker-gene CSV and display it as the cell output.
# NOTE(review): absolute machine-specific path -- consider a configurable data directory.
marker_df = pd.read_csv('/media/Lisa/Projects/IBD001/Xenium_analysis/Squidpy/final_combined_genes.csv')
marker_df

In [None]:
# Read the cell_id -> cell type annotation CSV; the join cell below uses its
# 'cell_id' and 'group' columns.
# NOTE(review): absolute machine-specific path -- consider a configurable data directory.
lisa_annotation = pd.read_csv('/media/Lisa/Projects/IBD001/Xenium_analysis/Squidpy/cell_id_to_celltype.csv')

In [None]:
lisa_annotation

In [None]:
# Work on a copy so adata_filtered itself does not gain the annotation column.
adata_with_lisa_annotation = adata_filtered.copy()

# Lookup of the annotated 'group' by cell_id. It is built as a new object, so
# `lisa_annotation` is left untouched and this cell can be re-run safely
# (rebinding it to an indexed copy made a second run fail with a KeyError).
group_by_cell_id = lisa_annotation.set_index('cell_id')['group']

# Left join on the 'cell_id' obs column: every cell is kept, cells without an
# annotation get NaN in 'group', and the existing obs index is preserved
# (no set_index / reset_index round trip that would discard it).
adata_with_lisa_annotation.obs = adata_with_lisa_annotation.obs.join(
    group_by_cell_id, on='cell_id', how='left'
)

In [None]:
lisa_annotation

In [None]:
# UMAP coloured by the joined 'group' annotation; the figure is also saved.
sc.pl.umap(
    adata_with_lisa_annotation,
    color=["group"],
    wspace=0.4,
    save='_xenium_with_lisa_annotation.png',
)