# Patches with Simulation Data

#### Description

The aim here is to apply Patches to the simulated data (here: t100,s80,b0) in its interpretability mode to get interpretable scores. These scores will then be compared to the gene ground truths (i.e., is a gene a type or state gene, or both) and to the typeness and stateness scores from the FSS paper.

Given gene expression profiles and condition labels, Patches produces two latent representations of each cell: a common representation (similar across all conditions) and a condition-specific representation (similar within each condition). From these representations, a linear decoder reconstructs the gene expression profiles, and its weights provide interpretable scores for each gene. One common score is returned per gene as the sum of the weights corresponding to the dimensions of the common latent representation. In addition, one attribute score is returned for each attribute of each condition as the sum of the weights corresponding to that attribute's dimensions in the condition-specific latent representation.
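
The summation described above can be illustrated with a toy example. This is only a sketch: the weight matrices, their shapes, and the mapping `attr_dims` from attributes to latent dimensions are made up for illustration and are not taken from the Patches implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

n_genes = 5
n_common = 4          # size of the common latent space (hypothetical)
attr_dims = {         # hypothetical layout of the condition-specific latent space
    "Condition1": [0, 1],
    "Condition2": [2, 3],
}

# hypothetical linear decoder weights: one row per latent dimension, one column per gene
w_common = rng.normal(size=(n_common, n_genes))
w_conditional = rng.normal(size=(4, n_genes))

# common score: sum of the weights over all common dimensions (one value per gene)
common_score = w_common.sum(axis=0)

# attribute scores: sum of the weights over the dimensions of each attribute
attribute_scores = {
    name: w_conditional[dims, :].sum(axis=0) for name, dims in attr_dims.items()
}
```

Each gene thus ends up with one common score and one score per attribute, which is the structure of the score columns written to `workflow.anndata.var` further down.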

Because of this, a cell type score can only be obtained if the cell types are given to the model as a condition, i.e., they must be known (or obtained by clustering + marker gene analysis). A single cell type or state score could potentially be obtained by combining the attribute scores of all type or state attributes, respectively (e.g., sum of absolute values or sum of squares). The common score might contain cell type information if the cell type is not given to the model as a condition, but this remains to be explored.
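
A minimal sketch of the two proposed ways to combine attribute scores into a single score per gene is shown below. The helper `combine_scores` and the toy values are hypothetical; only the column names follow the score columns saved later in this notebook.

```python
import pandas as pd


def combine_scores(scores: pd.DataFrame, columns: list[str], how: str = "abs") -> pd.Series:
    """Combine several attribute score columns into one score per gene."""
    sub = scores[columns]
    if how == "abs":
        # sum of absolute values
        return sub.abs().sum(axis=1)
    if how == "sq":
        # sum of squares
        return (sub ** 2).sum(axis=1)
    raise ValueError(f"how must be 'abs' or 'sq', got {how!r}")


# toy attribute scores for three genes (values are made up)
toy = pd.DataFrame(
    {
        "Condition1_score_Patches": [0.5, -0.1, 0.0],
        "Condition2_score_Patches": [-0.5, 0.2, 0.0],
    },
    index=["gene1", "gene2", "gene3"],
)

state_columns = ["Condition1_score_Patches", "Condition2_score_Patches"]
state_abs = combine_scores(toy, state_columns, how="abs")
state_sq = combine_scores(toy, state_columns, how="sq")
```

Both variants discard the sign of the attribute scores, so a gene with opposite-signed scores in the two conditions (like `gene1` in the toy data) still receives a high combined score.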

Patches is therefore run twice: once with only the state as condition, and once with both type and state as conditions. The scores are then explored as described above. The workflow is adapted from this tutorial: https://ladder.readthedocs.io/en/latest/notebooks/interpretable_vu_tutorial.html


#### Possible Improvements:

- data preparation:
    - run with all genes or run with HVGs (get scores for all genes vs. better performance with top 1500 genes)
    - ~~batch effect correction~~ *does not seem to have a noticeable influence here*

- `workflow.prep_model`:
    - ~~run with or without `batch_key` (here: `batch_key = sample_id`)~~ *does not make a noticeable difference, and it makes sense to do batch correction*
    - ~~use `cell_type_label_key = cluster_id` (clusters correspond to cell type in simulation)~~ *only used for subsetting evaluation for a cell type*
    - use different `minibatch_size` (default is 128, maybe try smaller size with ~2000 cells)
    - ~~tune `model_args` and `optim_args`~~ *maybe consider `lr`; otherwise not necessary*

- `workflow.run_model`:
    - run with different number of `max_epochs` (default is 1500)
    - try different `convergence_threshold` and `convergence_window`
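
To make the roles of `convergence_threshold` and `convergence_window` more concrete, here is a generic windowed convergence check. This is an assumption about how such parameters are commonly used, not the rule implemented in `ladder`; the function `has_converged` is hypothetical, and the library source should be consulted for the actual criterion.

```python
def has_converged(losses, threshold=1e-4, window=50):
    """Generic early-stopping rule (hypothetical, not the ladder implementation).

    Compares the mean loss of the last `window` epochs with the mean loss of
    the `window` epochs before that and reports convergence when the relative
    change falls below `threshold`.
    """
    if len(losses) < 2 * window:
        # not enough history to compare two full windows
        return False
    previous = sum(losses[-2 * window:-window]) / window
    latest = sum(losses[-window:]) / window
    relative_change = abs(previous - latest) / max(abs(previous), 1e-12)
    return relative_change < threshold


flat_losses = [1.0] * 100                          # loss no longer changes
falling_losses = [100.0 - i for i in range(100)]   # loss still decreasing

flat_result = has_converged(flat_losses)
falling_result = has_converged(falling_losses)
short_result = has_converged([1.0] * 10)
```

Under a rule of this kind, a larger window or a smaller threshold makes early stopping less likely, so training runs closer to `max_epochs`.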

## Import Packages

In [None]:
import scanpy as sc
import numpy as np 
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from ladder.scripts import InterpretableWorkflow
from umap import UMAP

## Helper Functions

In [None]:
def force_aspect(ax, aspect=1):
    """
    Force aspect ratio of a matplotlib axis to be equal. Helper function for plotting.
    Taken from Patches tutorial docs.

    Parameters
    ----------
    ax : matplotlib axis
        The axis to set the aspect ratio for.
    aspect : float, optional
        The aspect ratio to set. Default is 1 (equal).
    """
    im = ax.get_images()
    extent = im[0].get_extent()
    ax.set_aspect(abs((extent[1] - extent[0]) / (extent[3] - extent[2])) / aspect)

def create_umap_df(workflow, model):
    """
    Create a dataframe with UMAP and PCA reductions for plotting.
    Adapted from Patches tutorial docs.

    Parameters
    ----------
    workflow : InterpretableWorkflow
        The fitted InterpretableWorkflow object.
    model : str
        The model type, either "Patches" or "Base".
    
    Returns
    -------
    pd.DataFrame
        A dataframe with UMAP and PCA reductions for plotting.
    """
    # UMAP reducers
    reducer_base = UMAP(n_neighbors=50, min_dist=0.1, metric="correlation", verbose=False, random_state=42)
    reducer = UMAP(n_neighbors=50, min_dist=0.1, metric="correlation", verbose=False, random_state=42)

    match model:
        case "Patches":
            # all data reductions needed for the plots
            base_umap = reducer_base.fit_transform(np.array(workflow.anndata.X))
            z_umap = reducer.fit_transform(workflow.anndata.obsm['patches_z_latent']) # Z
            w_umap = reducer.fit_transform(workflow.anndata.obsm['patches_w_latent']) # W
            
            w_pca = sc.pp.pca(workflow.anndata.obsm['patches_w_latent'], random_state=42)[:,:2] # W PCA, grab first 2 PCs
            z_pca = sc.pp.pca(workflow.anndata.obsm['patches_z_latent'], random_state=42)[:,:2] # Z PCA, grab first 2 PCs

            # create dataframe and add all reductions
            df = pd.DataFrame(base_umap)
            df.index = workflow.anndata.obs.index
            
            df.columns = ["base_1", "base_2"]
            df["z_umap_1"], df["z_umap_2"] = z_umap[:,0], z_umap[:,1]
            df["w_umap_1"], df["w_umap_2"] = w_umap[:,0], w_umap[:,1]
            df["z_pc_1"], df["z_pc_2"] = z_pca[:,0], z_pca[:,1]
            df["w_pc_1"], df["w_pc_2"] = w_pca[:,0], w_pca[:,1]

        case "Base":
            anndata = workflow.anndata.copy()
            anndata.X = anndata.layers["counts"]
            
            sc.pp.normalize_total(anndata, target_sum=1e4)
            sc.pp.log1p(anndata)
            sc.tl.pca(anndata, svd_solver="arpack")
            
            base_umap = reducer_base.fit_transform(np.array(anndata.X))
            base_pca = anndata.obsm['X_pca'][:,:2]

            df = pd.DataFrame(base_umap)
            df.index = anndata.obs.index

            df.columns = ["base_umap_1", "base_umap_2"]
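            df["base_pc_1"], df["base_pc_2"] = base_pca[:,0], base_pca[:,1]

        case _:
            # fail early instead of hitting an undefined `df` below
            raise ValueError(f'model must be "Patches" or "Base", got {model!r}')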

    # add metadata
    df["group_id"], df["cluster_id"], df["sample_id"] = workflow.anndata.obs["group_id"], workflow.anndata.obs["cluster_id"], workflow.anndata.obs["sample_id"]

    return df

## Load and Prepare Data

Reconstruction performance seems to depend strongly on the genes used: a clear increase in performance is observed when using the 1500 most highly variable genes (HVGs) instead of all 1979 genes. Harsher filtering (e.g., the top 100 HVGs) does not visibly improve performance further. Although using all genes yields lower performance, it provides a larger set of genes to compare to the FSS scores (for which all genes are used). Therefore, workflows are run for both the HVGs and all genes.

In [None]:
# load data
adata = sc.read_h5ad("../../data/sim/01-pro/t100,s80,b0.h5ad")

# full adata
adata_full = adata.copy()

# highly variable gene selection on log-normalized data
adata.X = adata.layers["logcounts"]
sc.pp.highly_variable_genes(adata, n_top_genes=1500)
sc.pl.highly_variable_genes(adata)
adata_hvg = adata[:, adata.var["highly_variable"]].copy()

# reset to raw counts for model input (as stated in docs)
adata_full.X = adata_full.layers["counts"]
adata_hvg.X = adata_hvg.layers["counts"]

print(f"Full data shape: {adata_full.shape}, HVG data shape: {adata_hvg.shape}")

## Run Patches in Interpretable Workflow

### Hyperparameters

Performance (especially of the scores in downstream analyses) seems to be driven mainly by the number of epochs. Even though the loss decreases little after around 500 epochs, downstream performance of the Patches scores improves substantially when training for 1500 epochs instead of 500. Learning rate and minibatch size do not seem to make a significant difference in performance.

In [None]:
# hyperparameters for full data
lr_full = 0.0098
epochs_full = 1500
minibatch_size_full = 32

# hyperparameters for hvg data
lr_hvg = 0.0035
epochs_hvg = 1500
minibatch_size_hvg = 32

# hyperparameters for both
convergence_threshold = 1e-4
convergence_window = 50
random_seed = 42
batch_key = "sample_id"

### Full Data + Condition (State) Only - Workflow1

In [None]:
# initialize workflow object
workflow1 = InterpretableWorkflow(
    adata_full.copy(),
    verbose=True,
    random_seed=random_seed
)

# define the condition classes & batch key to prepare the data
factors = ["group_id"]
workflow1.prep_model(
    factors,
    batch_key=batch_key,
    minibatch_size=minibatch_size_full,
    model_type="Patches",
    model_args={"ld_normalize" : True},
    optim_args={"lr" : lr_full})

# run model
workflow1.run_model(max_epochs=epochs_full, convergence_threshold=convergence_threshold, convergence_window=convergence_window)
workflow1.save_model("../../data/sim/02-scores_patches/t100,s80,b0-patches_con_full")

In [None]:
# inspect loss
workflow1.plot_loss()

In [None]:
# save embeddings to adata object
workflow1.write_embeddings()
workflow1.anndata.obsm

In [None]:
# evaluate run
workflow1.evaluate_reconstruction()

In [None]:
# get loadings (gene scores)
workflow1.get_conditional_loadings()
workflow1.get_common_loadings()
workflow1.anndata.var

In [None]:
# save loadings to file
workflow1.anndata.var.loc[:, ["Condition1_score_Patches", "Condition2_score_Patches", "common_score_Patches"]].to_csv(
    "../../data/sim/02-scores_patches/t100,s80,b0-patches_con_full_loadings.csv"
)

In [None]:
# Figure skeleton (adapted from Patches tutorial docs)


## create dataframes for plotting
df_patches = create_umap_df(workflow1, "Patches")
df_base = create_umap_df(workflow1, "Base")


## color palettes
klee_palette = [
    "#8B1E3F",  # Deep Burgundy
    "#3B5998",  # Rich Blue
    "#F4A261",  # Warm Orange
    "#264653",  # Deep Teal
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#D3D9E3",  # Soft Pastel Blue
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]

klee_palette_masch = [
    "#3B5998",  # Rich Blue
    "#6A994E",  # Fresh Olive Green
    "#F4A261",  # Warm Orange
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#FFC8A2",  # Soft Peach
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]


## plot parameters
fontsize=14
alpha=0.3
s=10
s_pca=3


## create a figure with a 2x2 grid of subplots
fig = plt.figure(figsize=(21, 21))

## define a GridSpec with a 2x2 layout
gs = gridspec.GridSpec(2, 2, wspace=0.17, hspace = 0.3, figure=fig)

## create subplots for the 2x2 grid
ax = [fig.add_subplot(gs[i//2, i%2]) for i in range(4)]

for subax in ax:
    subax.axis('off')

## define a new GridSpec for axis to split vertically
gs_inner_topleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 0], wspace=0.1, hspace=0.15)
gs_inner_topright = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 1], wspace=0.1, hspace=0.15)
gs_inner_botleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1, 0], wspace=0.1, hspace=0.15)
gs_inner_botright = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=gs[1, 1], wspace=0.25)

## create subplots for the inner grid
ax_inner_topleft = [fig.add_subplot(gs_inner_topleft[i//2, i%2]) for i in range(4)]
ax_inner_topright = [fig.add_subplot(gs_inner_topright[i//2, i%2]) for i in range(4)]
ax_inner_botleft = [fig.add_subplot(gs_inner_botleft[i//2, i%2]) for i in range(4)]

## specific for botright
ax_inner_botright = [fig.add_subplot(gs_inner_botright[0,0])]
ax_inner_botright = ax_inner_botright \
+ [
    fig.add_subplot(gs_inner_botright[0,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,3],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,0],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,3],sharey=ax_inner_botright[0]),
]


## UMAP plots

### counts
con = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topright[0], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha)
clu = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topleft[0], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha)
sam = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_botleft[0], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha)

### Zs
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topright[2], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topleft[2], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_botleft[2], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)

### Ws
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topright[3], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topleft[3], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_botleft[3], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)


## PCA plots

### cluster
sns.stripplot(df_patches, y = "z_pc_1", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[0], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[0].axvline(zorder=2, color='black', linestyle = 'dashed')
    
sns.stripplot(df_patches, y = "z_pc_2", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[1], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[1].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_1", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[2], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[2].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_2", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[3], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[3].axvline(zorder=2, color='black', linestyle = 'dashed')


### condition
sns.stripplot(df_patches, y = "z_pc_1", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[4], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[4].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "z_pc_2", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[5], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[5].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_1", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[6], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[6].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_2", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[7], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[7].axvline(zorder=2, color='black', linestyle = 'dashed')


## formatting

# the three UMAP panels share the same formatting
for subax in ax_inner_topright + ax_inner_topleft + ax_inner_botleft:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_yticklabels([])
    subax.set_yticks([])
    subax.set_xlabel('UMAP 1', fontsize=fontsize*0.6)
    subax.set_ylabel('UMAP 2', fontsize=fontsize*0.6)
    try:
        force_aspect(subax)
    except IndexError:
        # force_aspect needs an image on the axis; scatter-only axes have none
        pass

for subax in ax_inner_botright:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_xlabel('')
    subax.set_ylabel('')

ax_inner_botright[0].set_ylabel('Principal Score', fontsize=fontsize*0.6)
ax_inner_botright[4].set_ylabel('Principal Score', fontsize=fontsize*0.6)

clu_h, clu_l = clu.get_legend_handles_labels() 
con_h, con_l = con.get_legend_handles_labels()
sam_h, sam_l = sam.get_legend_handles_labels()

clu.legend([], frameon=False); con.legend([], frameon=False); sam.legend([], frameon=False)

ax_inner_topleft[0].set_title('Normalized Counts', fontsize=fontsize)
ax_inner_topleft[1].set_title('Cell Identity (ρ)', fontsize=fontsize)
ax_inner_topleft[2].set_title('Common (Z)', fontsize=fontsize)
ax_inner_topleft[3].set_title('Conditional (W)', fontsize=fontsize)

ax_inner_topright[0].set_title('Normalized Counts', fontsize=fontsize)
ax_inner_topright[1].set_title('Cell Identity (ρ)', fontsize=fontsize)
ax_inner_topright[2].set_title('Common (Z)', fontsize=fontsize)
ax_inner_topright[3].set_title('Conditional (W)', fontsize=fontsize)

ax_inner_botleft[0].set_title('Normalized Counts', fontsize=fontsize)
ax_inner_botleft[1].set_title('Cell Identity (ρ)', fontsize=fontsize)
ax_inner_botleft[2].set_title('Common (Z)', fontsize=fontsize)
ax_inner_botleft[3].set_title('Conditional (W)', fontsize=fontsize)

ax_inner_botright[0].set_title('Z - PC 1', fontsize=fontsize)
ax_inner_botright[1].set_title('Z - PC 2', fontsize=fontsize)
ax_inner_botright[2].set_title('W - PC 1', fontsize=fontsize)
ax_inner_botright[3].set_title('W - PC 2', fontsize=fontsize)
ax_inner_botright[4].set_title('Z - PC 1', fontsize=fontsize)
ax_inner_botright[5].set_title('Z - PC 2', fontsize=fontsize)
ax_inner_botright[6].set_title('W - PC 1', fontsize=fontsize)
ax_inner_botright[7].set_title('W - PC 2', fontsize=fontsize)


## set subplot titles
titles = ['Clusters', 'Conditions', 'Samples', '']

for a, t in zip(ax, titles):
    a.set_title(t, y=1.05, fontsize=fontsize*1.2)


## save and show figure
fig.suptitle("Patches - Conditions Only - Full Data", fontsize=18, y=0.95)
plt.savefig("../../outs/sim/t100,s80,b0-patches_con_full.png", dpi=300, bbox_inches='tight')
plt.show()

### HVG Data + Condition (State) Only - Workflow2

In [None]:
# initialize workflow object
workflow2 = InterpretableWorkflow(
    adata_hvg.copy(),
    verbose=True,
    random_seed=random_seed
)

# define the condition classes & batch key to prepare the data
factors = ["group_id"]
workflow2.prep_model(
    factors,
    batch_key=batch_key,
    minibatch_size=minibatch_size_hvg,
    model_type="Patches",
    model_args={"ld_normalize" : True},
    optim_args={"lr" : lr_hvg})

# run model
workflow2.run_model(max_epochs=epochs_hvg, convergence_threshold=convergence_threshold, convergence_window=convergence_window)
workflow2.save_model("../../data/sim/02-scores_patches/t100,s80,b0-patches_con_hvg")

In [None]:
# inspect loss
workflow2.plot_loss()

In [None]:
# save embeddings to adata object
workflow2.write_embeddings()
workflow2.anndata.obsm

In [None]:
# evaluate run
workflow2.evaluate_reconstruction()

In [None]:
# get loadings (gene scores)
workflow2.get_conditional_loadings()
workflow2.get_common_loadings()
workflow2.anndata.var

In [None]:
# save loadings to file
workflow2.anndata.var.loc[:, ["Condition1_score_Patches", "Condition2_score_Patches", "common_score_Patches"]].to_csv(
    "../../data/sim/02-scores_patches/t100,s80,b0-patches_con_hvg_loadings.csv"
)

In [None]:
# Figure skeleton (adapted from Patches tutorial docs)


## create dataframes for plotting
df_patches = create_umap_df(workflow2, "Patches")
df_base = create_umap_df(workflow2, "Base")


## color palettes
klee_palette = [
    "#8B1E3F",  # Deep Burgundy
    "#3B5998",  # Rich Blue
    "#F4A261",  # Warm Orange
    "#264653",  # Deep Teal
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#D3D9E3",  # Soft Pastel Blue
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]

klee_palette_masch = [
    "#3B5998",  # Rich Blue
    "#6A994E",  # Fresh Olive Green
    "#F4A261",  # Warm Orange
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#FFC8A2",  # Soft Peach
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]


## plot parameters
fontsize=14
alpha=0.3
s=10
s_pca=3


## create a figure with a 2x2 grid of subplots
fig = plt.figure(figsize=(21, 21))

## define a GridSpec with a 2x2 layout
gs = gridspec.GridSpec(2, 2, wspace=0.17, hspace = 0.3, figure=fig)

## create subplots for the 2x2 grid
ax = [fig.add_subplot(gs[i//2, i%2]) for i in range(4)]

for subax in ax:
    subax.axis('off')

## define a new GridSpec for axis to split vertically
gs_inner_topleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 0], wspace=0.1, hspace=0.15)
gs_inner_topright = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 1], wspace=0.1, hspace=0.15)
gs_inner_botleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1, 0], wspace=0.1, hspace=0.15)
gs_inner_botright = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=gs[1, 1], wspace=0.25)

## create subplots for the inner grid
ax_inner_topleft = [fig.add_subplot(gs_inner_topleft[i//2, i%2]) for i in range(4)]
ax_inner_topright = [fig.add_subplot(gs_inner_topright[i//2, i%2]) for i in range(4)]
ax_inner_botleft = [fig.add_subplot(gs_inner_botleft[i//2, i%2]) for i in range(4)]

## specific for botright
ax_inner_botright = [fig.add_subplot(gs_inner_botright[0,0])]
ax_inner_botright = ax_inner_botright \
+ [
    fig.add_subplot(gs_inner_botright[0,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,3],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,0],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,3],sharey=ax_inner_botright[0]),
]


## UMAP plots

### counts
con = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topright[0], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha)
clu = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topleft[0], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha)
sam = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_botleft[0], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha)

### Zs
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topright[2], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topleft[2], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_botleft[2], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)

### Ws
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topright[3], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topleft[3], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_botleft[3], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)


## PCA plots

### cluster
sns.stripplot(df_patches, y = "z_pc_1", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[0], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[0].axvline(zorder=2, color='black', linestyle = 'dashed')
    
sns.stripplot(df_patches, y = "z_pc_2", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[1], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[1].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_1", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[2], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[2].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_2", hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[3], legend=False, palette=sns.color_palette(klee_palette[3:6]))
#ax_inner_botright[3].axvline(zorder=2, color='black', linestyle = 'dashed')


### condition
sns.stripplot(df_patches, y = "z_pc_1", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[4], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[4].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "z_pc_2", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[5], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[5].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_1", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[6], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[6].axvline(zorder=2, color='black', linestyle = 'dashed')

sns.stripplot(df_patches, y = "w_pc_2", hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[7], legend=False, palette=sns.color_palette(klee_palette[0:2]))
#ax_inner_botright[7].axvline(zorder=2, color='black', linestyle = 'dashed')


## formatting

# the three UMAP panels share the same formatting
for subax in ax_inner_topright + ax_inner_topleft + ax_inner_botleft:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_yticklabels([])
    subax.set_yticks([])
    subax.set_xlabel('UMAP 1', fontsize=fontsize*0.6)
    subax.set_ylabel('UMAP 2', fontsize=fontsize*0.6)
    try:
        force_aspect(subax)
    except IndexError:
        # force_aspect needs an image on the axis; scatter-only axes have none
        pass

for subax in ax_inner_botright:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_xlabel('')
    subax.set_ylabel('')

ax_inner_botright[0].set_ylabel('Principal Score', fontsize=fontsize*0.6)
ax_inner_botright[4].set_ylabel('Principal Score', fontsize=fontsize*0.6)

clu_h, clu_l = clu.get_legend_handles_labels() 
con_h, con_l = con.get_legend_handles_labels()
sam_h, sam_l = sam.get_legend_handles_labels()

clu.legend([], frameon=False); con.legend([], frameon=False); sam.legend([], frameon=False)

ax_inner_topleft[0].set_title('Normalized Counts', fontsize=fontsize)
ax_inner_topleft[1].set_title('Cell Identity (ρ)', fontsize=fontsize)
ax_inner_topleft[2].set_title('Common (Z)', fontsize=fontsize)
ax_inner_topleft[3].set_title('Conditional (W)', fontsize=fontsize)

ax_inner_topright[0].set_title('Normalized Counts', fontsize=fontsize)
ax_inner_topright[1].set_title('Cell Identity (ρ)', fontsize=fontsize)
ax_inner_topright[2].set_title('Common (Z)', fontsize=fontsize)
ax_inner_topright[3].set_title('Conditional (W)', fontsize=fontsize)

ax_inner_botleft[0].set_title('Normalized Counts', fontsize=fontsize)
ax_inner_botleft[1].set_title('Cell Identity (ρ)', fontsize=fontsize)
ax_inner_botleft[2].set_title('Common (Z)', fontsize=fontsize)
ax_inner_botleft[3].set_title('Conditional (W)', fontsize=fontsize)

ax_inner_botright[0].set_title('Z - PC 1', fontsize=fontsize)
ax_inner_botright[1].set_title('Z - PC 2', fontsize=fontsize)
ax_inner_botright[2].set_title('W - PC 1', fontsize=fontsize)
ax_inner_botright[3].set_title('W - PC 2', fontsize=fontsize)
ax_inner_botright[4].set_title('Z - PC 1', fontsize=fontsize)
ax_inner_botright[5].set_title('Z - PC 2', fontsize=fontsize)
ax_inner_botright[6].set_title('W - PC 1', fontsize=fontsize)
ax_inner_botright[7].set_title('W - PC 2', fontsize=fontsize)


## set subplot titles
titles = ['Clusters', 'Conditions', 'Samples', '']

for a, t in zip(ax, titles):
    a.set_title(t, y=1.05, fontsize=fontsize*1.2)


## save and show figure
fig.suptitle("Patches - Conditions Only - HVG Data", fontsize=18, y=0.95)
plt.savefig("../../outs/sim/t100,s80,b0-patches_con_hvg.png", dpi=300, bbox_inches='tight')
plt.show()

### Full Data + Cluster (Type) & Condition (State) - Workflow3

In [None]:
# initialize workflow object
workflow3 = InterpretableWorkflow(
    adata_full.copy(),
    verbose=True,
    random_seed=random_seed
)

# define the condition classes & batch key to prepare the data
factors = ["cluster_id", "group_id"]
workflow3.prep_model(
    factors,
    batch_key=batch_key,
    minibatch_size=minibatch_size_full,
    model_type="Patches",
    model_args={"ld_normalize" : True},
    optim_args={"lr" : lr_full})

# run model
workflow3.run_model(max_epochs=epochs_full, convergence_threshold=convergence_threshold, convergence_window=convergence_window)
workflow3.save_model("../../data/sim/02-scores_patches/t100,s80,b0-patches_clu_con_full")

In [None]:
# inspect loss
workflow3.plot_loss()

In [None]:
# save embeddings to adata object
workflow3.write_embeddings()
workflow3.anndata.obsm

In [None]:
# evaluate run
workflow3.evaluate_reconstruction()

In [None]:
# get loadings (gene scores)
workflow3.get_conditional_loadings()
workflow3.get_common_loadings()
workflow3.anndata.var

In [None]:
# save loadings to file
workflow3.anndata.var.loc[:, ["Group1_score_Patches", "Group2_score_Patches", "Group3_score_Patches", "Condition1_score_Patches", "Condition2_score_Patches", "common_score_Patches"]].to_csv(
    "../../data/sim/02-scores_patches/t100,s80,b0-patches_clu_con_full_loadings.csv"
)

In [None]:
# Figure skeleton (adapted from Patches tutorial docs)


## create dataframes for plotting
df_patches = create_umap_df(workflow3, "Patches")
df_base = create_umap_df(workflow3, "Base")


## color palettes
klee_palette = [
    "#8B1E3F",  # Deep Burgundy
    "#3B5998",  # Rich Blue
    "#F4A261",  # Warm Orange
    "#264653",  # Deep Teal
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#D3D9E3",  # Soft Pastel Blue
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]

klee_palette_masch = [
    "#3B5998",  # Rich Blue
    "#6A994E",  # Fresh Olive Green
    "#F4A261",  # Warm Orange
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#FFC8A2",  # Soft Peach
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]


## plot parameters
fontsize=14
alpha=0.3
s=10
s_pca=3


## create a figure with a 2x2 grid of subplots
fig = plt.figure(figsize=(21, 21))

## define a GridSpec with a 2x2 layout
gs = gridspec.GridSpec(2, 2, wspace=0.17, hspace = 0.3, figure=fig)

## create subplots for the 2x2 grid
ax = [fig.add_subplot(gs[i//2, i%2]) for i in range(4)]

for subax in ax:
    subax.axis('off')

## define nested GridSpecs that subdivide each panel of the outer 2x2 grid
gs_inner_topleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 0], wspace=0.1, hspace=0.15)
gs_inner_topright = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 1], wspace=0.1, hspace=0.15)
gs_inner_botleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1, 0], wspace=0.1, hspace=0.15)
gs_inner_botright = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=gs[1, 1], wspace=0.25)

## create subplots for the inner grid
ax_inner_topleft = [fig.add_subplot(gs_inner_topleft[i//2, i%2]) for i in range(4)]
ax_inner_topright = [fig.add_subplot(gs_inner_topright[i//2, i%2]) for i in range(4)]
ax_inner_botleft = [fig.add_subplot(gs_inner_botleft[i//2, i%2]) for i in range(4)]

## specific for botright
ax_inner_botright = [fig.add_subplot(gs_inner_botright[0,0])]
ax_inner_botright = ax_inner_botright \
+ [
    fig.add_subplot(gs_inner_botright[0,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,3],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,0],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,3],sharey=ax_inner_botright[0]),
]


## UMAP plots

### counts
con = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topright[0], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha)
clu = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topleft[0], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha)
sam = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_botleft[0], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha)

### Zs
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topright[2], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topleft[2], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_botleft[2], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)

### Ws
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topright[3], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topleft[3], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_botleft[3], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)


## PCA plots

pc_cols = ["z_pc_1", "z_pc_2", "w_pc_1", "w_pc_2"]

### cluster (top row)
for i, pc_col in enumerate(pc_cols):
    sns.stripplot(df_patches, y=pc_col, hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[i], legend=False, palette=sns.color_palette(klee_palette[3:6]))

### condition (bottom row)
for i, pc_col in enumerate(pc_cols):
    sns.stripplot(df_patches, y=pc_col, hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[i + 4], legend=False, palette=sns.color_palette(klee_palette[0:2]))


## formatting

### UMAP panels: hide ticks, label axes, and square the aspect ratio where possible
for subax in ax_inner_topright + ax_inner_topleft + ax_inner_botleft:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_yticklabels([])
    subax.set_yticks([])
    subax.set_xlabel('UMAP 1', fontsize=fontsize*0.6)
    subax.set_ylabel('UMAP 2', fontsize=fontsize*0.6)
    try:
        force_aspect(subax)
    except Exception:
        # best effort: skip axes for which force_aspect raises
        pass

for subax in ax_inner_botright:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_xlabel('')
    subax.set_ylabel('')

ax_inner_botright[0].set_ylabel('Principal Score', fontsize=fontsize*0.6)
ax_inner_botright[4].set_ylabel('Principal Score', fontsize=fontsize*0.6)

clu_h, clu_l = clu.get_legend_handles_labels() 
con_h, con_l = con.get_legend_handles_labels()
sam_h, sam_l = sam.get_legend_handles_labels()

clu.legend([], frameon=False); con.legend([], frameon=False); sam.legend([], frameon=False)

umap_titles = ['Normalized Counts', 'Cell Identity (ρ)', 'Common (Z)', 'Conditional (W)']
for panel in (ax_inner_topleft, ax_inner_topright, ax_inner_botleft):
    for subax, title in zip(panel, umap_titles):
        subax.set_title(title, fontsize=fontsize)

pc_titles = ['Z - PC 1', 'Z - PC 2', 'W - PC 1', 'W - PC 2'] * 2
for subax, title in zip(ax_inner_botright, pc_titles):
    subax.set_title(title, fontsize=fontsize)


## set subplot titles
titles = ['Clusters', 'Conditions', 'Samples', '']

for a, t in zip(ax, titles):
    a.set_title(t, y=1.05, fontsize=fontsize*1.2)

    
## save and show figure
fig.suptitle("Patches - Cluster + Condition - Full Data", fontsize=18, y=0.95)
plt.savefig("../../outs/sim/t100,s80,b0-patches_clu_con_full.png", dpi=300, bbox_inches='tight')
plt.show()

### HVG Data + Cluster (Type) & Condition (State) - Workflow4

In [None]:
# initialize workflow object
workflow4 = InterpretableWorkflow(
    adata_hvg.copy(),
    verbose=True,
    random_seed=random_seed
)

# define the condition classes & batch key to prepare the data
factors = ["cluster_id", "group_id"]
workflow4.prep_model(
    factors,
    batch_key=batch_key,
    minibatch_size=minibatch_size_hvg,
    model_type="Patches",
    model_args={"ld_normalize" : True},
    optim_args={"lr" : lr_hvg})

# run model
workflow4.run_model(max_epochs=epochs_hvg, convergence_threshold=convergence_threshold, convergence_window=convergence_window)
workflow4.save_model("../../data/sim/02-scores_patches/t100,s80,b0-patches_clu_con_hvg")

In [None]:
# inspect loss
workflow4.plot_loss()

In [None]:
# save embeddings to adata object
workflow4.write_embeddings()
workflow4.anndata.obsm

In [None]:
# evaluate run
workflow4.evaluate_reconstruction()

In [None]:
# get loadings (gene scores)
workflow4.get_conditional_loadings()
workflow4.get_common_loadings()
workflow4.anndata.var

In [None]:
# save loadings to file
workflow4.anndata.var.loc[:, ["Group1_score_Patches", "Group2_score_Patches", "Group3_score_Patches", "Condition1_score_Patches", "Condition2_score_Patches", "common_score_Patches"]].to_csv(
    "../../data/sim/02-scores_patches/t100,s80,b0-patches_clu_con_hvg_loadings.csv"
)
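
With loadings saved for both the full and the HVG run, one way to check how much gene selection changes the scores is a rank correlation on the genes both runs share. The sketch below is an illustration under assumptions, not part of the original workflow: `rank_agreement` is a hypothetical helper, the two toy tables stand in for the saved CSVs, and Spearman correlation is computed as the Pearson correlation of ranks.

```python
import pandas as pd

# Toy stand-ins for the full-data and HVG loadings tables; the values are made up.
full = pd.DataFrame(
    {"common_score_Patches": [0.9, -0.1, 0.4, 0.2, 0.0]},
    index=["g1", "g2", "g3", "g4", "g5"],
)
hvg = pd.DataFrame(
    {"common_score_Patches": [1.2, -0.05, 0.6]},
    index=["g1", "g2", "g3"],
)


def rank_agreement(full, hvg, column):
    """Spearman correlation of one score column over the genes shared by both runs.

    Hypothetical helper: ranks each table's scores on the shared genes and
    returns the Pearson correlation of those ranks.
    """
    shared = full.index.intersection(hvg.index)
    if len(shared) < 2:
        raise ValueError("need at least two shared genes to correlate")
    ranks_full = full.loc[shared, column].rank()
    ranks_hvg = hvg.loc[shared, column].rank()
    return ranks_full.corr(ranks_hvg)


rho = rank_agreement(full, hvg, "common_score_Patches")
print(f"rank agreement on shared genes: {rho:.3f}")
```

For the real runs, the two tables could be loaded with `pd.read_csv(path, index_col=0)` from the `..._full_loadings.csv` and `..._hvg_loadings.csv` files written above. Whether to rank raw or absolute scores is a choice that depends on whether the sign of a loading is considered meaningful.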

In [None]:
# Figure skeleton (adapted from Patches tutorial docs)


## create dataframes for plotting
df_patches = create_umap_df(workflow4, "Patches")
df_base = create_umap_df(workflow4, "Base")


## color palettes
klee_palette = [
    "#8B1E3F",  # Deep Burgundy
    "#3B5998",  # Rich Blue
    "#F4A261",  # Warm Orange
    "#264653",  # Deep Teal
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#D3D9E3",  # Soft Pastel Blue
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]

klee_palette_masch = [
    "#3B5998",  # Rich Blue
    "#6A994E",  # Fresh Olive Green
    "#F4A261",  # Warm Orange
    "#E9C46A",  # Soft Yellow
    "#2A9D8F",  # Muted Green
    "#E76F51",  # Burnt Sienna
    "#FFC8A2",  # Soft Peach
    "#A8DADC",  # Pale Turquoise
    "#BC4749",  # Warm Cranberry Red
]


## plot parameters
fontsize=14
alpha=0.3
s=10
s_pca=3


## create a figure with a 2x2 grid of subplots
fig = plt.figure(figsize=(21, 21))

## define a GridSpec with a 2x2 layout
gs = gridspec.GridSpec(2, 2, wspace=0.17, hspace = 0.3, figure=fig)

## create subplots for the 2x2 grid
ax = [fig.add_subplot(gs[i//2, i%2]) for i in range(4)]

for subax in ax:
    subax.axis('off')

## define nested GridSpecs that subdivide each panel of the outer 2x2 grid
gs_inner_topleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 0], wspace=0.1, hspace=0.15)
gs_inner_topright = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 1], wspace=0.1, hspace=0.15)
gs_inner_botleft = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1, 0], wspace=0.1, hspace=0.15)
gs_inner_botright = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=gs[1, 1], wspace=0.25)

## create subplots for the inner grid
ax_inner_topleft = [fig.add_subplot(gs_inner_topleft[i//2, i%2]) for i in range(4)]
ax_inner_topright = [fig.add_subplot(gs_inner_topright[i//2, i%2]) for i in range(4)]
ax_inner_botleft = [fig.add_subplot(gs_inner_botleft[i//2, i%2]) for i in range(4)]

## specific for botright
ax_inner_botright = [fig.add_subplot(gs_inner_botright[0,0])]
ax_inner_botright = ax_inner_botright \
+ [
    fig.add_subplot(gs_inner_botright[0,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[0,3],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,0],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,1],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,2],sharey=ax_inner_botright[0]),
    fig.add_subplot(gs_inner_botright[1,3],sharey=ax_inner_botright[0]),
]


## UMAP plots

### counts
con = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topright[0], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha)
clu = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_topleft[0], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha)
sam = sns.scatterplot(df_base, x='base_umap_1', y='base_umap_2', ax=ax_inner_botleft[0], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha)

### Zs
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topright[2], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_topleft[2], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='z_umap_1', y='z_umap_2', ax=ax_inner_botleft[2], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)

### Ws
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topright[3], hue = 'group_id', palette=sns.color_palette(klee_palette[0:2]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_topleft[3], hue = 'cluster_id', palette=sns.color_palette(klee_palette[3:6]), s=s, alpha=alpha, legend=False)
sns.scatterplot(df_patches, x='w_umap_1', y='w_umap_2', ax=ax_inner_botleft[3], hue = 'sample_id', palette=sns.color_palette(klee_palette[0:6]), s=s, alpha=alpha, legend=False)


## PCA plots

pc_cols = ["z_pc_1", "z_pc_2", "w_pc_1", "w_pc_2"]

### cluster (top row)
for i, pc_col in enumerate(pc_cols):
    sns.stripplot(df_patches, y=pc_col, hue='cluster_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[i], legend=False, palette=sns.color_palette(klee_palette[3:6]))

### condition (bottom row)
for i, pc_col in enumerate(pc_cols):
    sns.stripplot(df_patches, y=pc_col, hue='group_id', zorder=1, alpha=alpha, s=s_pca, ax=ax_inner_botright[i + 4], legend=False, palette=sns.color_palette(klee_palette[0:2]))


## formatting

### UMAP panels: hide ticks, label axes, and square the aspect ratio where possible
for subax in ax_inner_topright + ax_inner_topleft + ax_inner_botleft:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_yticklabels([])
    subax.set_yticks([])
    subax.set_xlabel('UMAP 1', fontsize=fontsize*0.6)
    subax.set_ylabel('UMAP 2', fontsize=fontsize*0.6)
    try:
        force_aspect(subax)
    except Exception:
        # best effort: skip axes for which force_aspect raises
        pass

for subax in ax_inner_botright:
    subax.set_xticklabels([])
    subax.set_xticks([])
    subax.set_xlabel('')
    subax.set_ylabel('')

ax_inner_botright[0].set_ylabel('Principal Score', fontsize=fontsize*0.6)
ax_inner_botright[4].set_ylabel('Principal Score', fontsize=fontsize*0.6)

clu_h, clu_l = clu.get_legend_handles_labels() 
con_h, con_l = con.get_legend_handles_labels()
sam_h, sam_l = sam.get_legend_handles_labels()

clu.legend([], frameon=False); con.legend([], frameon=False); sam.legend([], frameon=False)

umap_titles = ['Normalized Counts', 'Cell Identity (ρ)', 'Common (Z)', 'Conditional (W)']
for panel in (ax_inner_topleft, ax_inner_topright, ax_inner_botleft):
    for subax, title in zip(panel, umap_titles):
        subax.set_title(title, fontsize=fontsize)

pc_titles = ['Z - PC 1', 'Z - PC 2', 'W - PC 1', 'W - PC 2'] * 2
for subax, title in zip(ax_inner_botright, pc_titles):
    subax.set_title(title, fontsize=fontsize)


## set subplot titles
titles = ['Clusters', 'Conditions', 'Samples', '']

for a, t in zip(ax, titles):
    a.set_title(t, y=1.05, fontsize=fontsize*1.2)

    
## save and show figure
fig.suptitle("Patches - Cluster + Condition - HVG Data", fontsize=18, y=0.95)
plt.savefig("../../outs/sim/t100,s80,b0-patches_clu_con_hvg.png", dpi=300, bbox_inches='tight')
plt.show()

## Old Code

In [None]:
# sc.pp.neighbors(workflow.anndata, use_rep="PCA")
# sc.tl.umap(workflow.anndata)
# sc.pl.umap(workflow.anndata, color=["group_id", "cluster_id"], title=["PCA", "PCA"])

In [None]:
# sc.pp.neighbors(workflow.anndata, use_rep="patches_w_latent")
# sc.tl.umap(workflow.anndata)
# sc.pl.umap(workflow.anndata, color=["group_id", "cluster_id"], title=["Conditional (W)", "Conditional (W)"])

In [None]:
# sc.pp.neighbors(workflow.anndata, use_rep="patches_z_latent")
# sc.tl.umap(workflow.anndata)
# sc.pl.umap(workflow.anndata, color=["group_id", "cluster_id"], title=["Common (Z)", "Common (Z)"])