In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
from cellwhisperer.config import get_path
import numpy as np
import seaborn as sns


# Activate the project's matplotlib style sheet; its location is looked up through
# cellwhisperer's get_path helper under the "plot_style" key.
plt.style.use(get_path(["plot_style"]))

# Read the .h5ad file that Snakemake passes in as the first input.
adata = sc.read_h5ad(snakemake.input[0])


In [None]:
# Parse the submission-date column into pandas datetimes.
# NOTE(review): the earlier comment here spoke of dropping 57 observations with a
# missing date, but no rows are removed in this cell -- confirm whether that
# filtering happens upstream or is still required.
adata.obs["series_submission_date"] = adata.obs["series_submission_date"].pipe(pd.to_datetime)

# sc.pl.umap looks for its coordinates under obsm["X_umap"], so expose the
# CellWhisperer embedding under that key.
adata.obsm["X_umap"] = adata.obsm["X_cellwhisperer_umap"]

## Plot UMAP with labeled clusters

In [None]:
# Double-check that the plotting defaults are set as intended.
# scanpy=False is passed so that this call does not apply scanpy's own rcParams
# defaults on top of the style sheet loaded earlier.
sc.set_figure_params(scanpy=False, dpi_save=400, vector_friendly=True)

# Human-readable label -> matplotlib rcParams key, in the order they are reported.
font_settings = {
    "Font size": "font.size",
    "Title font size": "axes.titlesize",
    "Axis label font size": "axes.labelsize",
    "X-tick label font size": "xtick.labelsize",
    "Y-tick label font size": "ytick.labelsize",
    "Legend font size": "legend.fontsize",
}

print("Current font sizes:")
for label, rc_key in font_settings.items():
    print(f"{label}: {plt.rcParams[rc_key]}")

In [None]:
import itertools
import random

import matplotlib.cm as cm
from server.common.colors import CSS4_NAMED_COLORS  # cellxgene import

# Build one palette entry per Leiden cluster. The CSS4 colour collection is cycled,
# so the palette is always exactly n_categories long even if there are more
# clusters than named colours.
n_categories = adata.obs["leiden"].nunique()
css4_palette = [
    color
    for _, color in zip(range(n_categories), itertools.cycle(CSS4_NAMED_COLORS))
]

# Shuffle the palette order with a fixed seed so the colour assignment is
# reproducible. Seed 8 also works fine (the colours differ from the UMAP below,
# unfortunately).
random.seed(11)
random.shuffle(css4_palette)

# UMAP coloured by Leiden cluster, with the cluster ids written on the data.
fig, ax = plt.subplots(figsize=(5, 5.2))
sc.pl.umap(
    adata,
    color=["leiden"],
    palette=css4_palette,
    legend_loc="on data",
    legend_fontsize=6,
    legend_fontweight="normal",
    size=0.5,
    ax=ax,
)

In [None]:
# Flag every cell that belongs to one of the clusters listed in the Snakemake
# parameter `highlight_clusters`.
highlight_mask = adata.obs["cluster_label"].isin(snakemake.params.highlight_clusters)

# Ascending argsort places False (background) before True (highlighted); together
# with sort_order=False below, the highlighted cells are therefore plotted last.
sorted_indices = highlight_mask.argsort()

fig, ax = plt.subplots(figsize=(8, 8))
sc.pl.umap(
    adata[sorted_indices],
    color=["cluster_label"],
    palette=css4_palette,
    size=2,
    legend_fontsize=6,
    legend_fontweight="normal",
    sort_order=False,
    show=False,
    ax=ax,
)
# The per-cluster legend is not wanted in this panel.
ax.get_legend().remove()

# Outline each highlighted cluster with a thin circle.
umap_coords = adata.obsm["X_umap"]
for cluster in snakemake.params.highlight_clusters:
    in_cluster = (adata.obs["cluster_label"] == cluster).to_numpy()
    location_cells = umap_coords[in_cluster]
    # Centre of the circle: mean UMAP position of the cluster's cells.
    center = (location_cells[:, 0].mean(), location_cells[:, 1].mean())
    # Radius: square root of the summed per-axis variance of those positions.
    radius = np.sqrt(sum(location_cells.var(axis=0)))
    outline = plt.Circle(center, radius, fill=False, color="k", linewidth=0.5, clip_on=False)
    ax.add_patch(outline)

fig.savefig(snakemake.output.cluster_labeled)

In [None]:
# Larger overview UMAP with every cluster label written directly on the data,
# coloured with matplotlib's own CSS4 colour table.
fig, ax = plt.subplots(figsize=(10, 10))
sc.pl.umap(
    adata,
    color=["cluster_label"],
    palette=list(matplotlib.colors.CSS4_COLORS.values()),
    legend_loc="on data",
    legend_fontsize=6,
    legend_fontweight="normal",
    size=2,
    show=False,
    ax=ax,
)

# Outline each highlighted cluster with a bold circle.
umap_coords = adata.obsm["X_umap"]
for cluster in snakemake.params.highlight_clusters:
    in_cluster = (adata.obs["cluster_label"] == cluster).to_numpy()
    location_cells = umap_coords[in_cluster]
    # Centre of the circle: mean UMAP position of the cluster's cells.
    center = (location_cells[:, 0].mean(), location_cells[:, 1].mean())
    # Radius: square root of the summed per-axis variance of those positions.
    radius = np.sqrt(sum(location_cells.var(axis=0)))
    outline = plt.Circle(center, radius, fill=False, color="k", linewidth=2, clip_on=False)
    ax.add_patch(outline)

plt.tight_layout()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Colour used for the entire lower half of the custom colormap: the first colour
# of viridis (any other colour could be chosen here).
base_color = plt.cm.viridis(0.0)

# Assemble 256 colours (the default number of entries in a colormap):
#   * entries   0-127: the constant base_color
#   * entries 128-255: the full viridis range (0 -> 1) sampled at 128 points,
#     i.e. all of viridis compressed into the upper half
colors = [base_color] * 128 + list(plt.cm.viridis(np.linspace(0, 1, 128)))

# Turn the colour list into a colormap object.
custom_cmap = LinearSegmentedColormap.from_list("custom_viridis", colors)

# Visual sanity check: render a 0 -> 1 gradient with the new colormap.
gradient = np.linspace(0, 1, 256).reshape(1, -1)
plt.imshow(gradient, aspect="auto", cmap=custom_cmap)
plt.colorbar()
plt.show()


In [None]:
# UMAP coloured by the continuous submission date, with the highlighted clusters
# drawn on top of the background cells and outlined with circles.

# Fixed seed for the draw-order shuffle below. Without it the overlap pattern of
# the points -- and therefore the saved figure -- changes on every run.
rng = np.random.default_rng(0)

# Split the cell positions into highlighted and background (non-highlighted) cells.
highlight_mask = adata.obs["cluster_label"].isin(snakemake.params.highlight_clusters)
is_highlighted = highlight_mask.to_numpy()

# Shuffle within each group separately so that, inside a group, no cell is
# systematically drawn on top because of its position in the dataset.
highlight_indices = rng.permutation(np.flatnonzero(is_highlighted))
non_highlight_indices = rng.permutation(np.flatnonzero(~is_highlighted))

# Background cells first, highlighted cells last; with sort_order=False scanpy
# keeps this order instead of sorting the points by value.
shuffled_indices = np.concatenate((non_highlight_indices, highlight_indices))

fig, ax = plt.subplots(figsize=(5.5, 5))
sc.pl.umap(
    adata[shuffled_indices],
    color=["series_submission_date_cont"],
    cmap=custom_cmap,
    legend_loc="on data",
    legend_fontsize=6,
    legend_fontweight="normal",
    size=1,
    sort_order=False,
    show=False,
    ax=ax,
)

# Outline each highlighted cluster with a thin circle.
for cluster in snakemake.params.highlight_clusters:
    location_cells = adata[adata.obs.cluster_label == cluster, :].obsm["X_umap"]
    # Centre of the circle: mean UMAP position of the cluster's cells.
    x = location_cells[:, 0].mean()
    y = location_cells[:, 1].mean()
    # Radius: square root of the summed per-axis variance of those positions.
    size = np.sqrt(sum(location_cells.var(axis=0)))
    circle = plt.Circle((x, y), size, color="k", clip_on=False, fill=False, linewidth=0.5)
    ax.add_patch(circle)

fig.savefig(snakemake.output.submission_date_labeled)

In [None]:
# One small histogram + KDE of the continuous submission date per highlighted
# cluster, stacked vertically on a shared x-axis.
highlight_clusters = list(snakemake.params.highlight_clusters)
n_panels = len(highlight_clusters)

# squeeze=False makes plt.subplots return a 2-D array of Axes even for a single
# panel; with the default squeeze a lone highlight cluster yields a bare Axes
# object and the zip() below fails because it is not iterable.
fig, axes_grid = plt.subplots(
    n_panels, 1, sharex=True, squeeze=False, figsize=(2, n_panels * 0.7)
)
axes = axes_grid[:, 0]

for ax, cluster_label in zip(axes, highlight_clusters):
    in_cluster = adata.obs["cluster_label"] == cluster_label
    cluster_dates = adata.obs.loc[in_cluster, "series_submission_date_cont"]
    sns.histplot(cluster_dates, kde=True, ax=ax, bins=50, color="black")
    # Title each panel with its cluster and pin the x ticks to three fixed years.
    ax.set(title=cluster_label, xticks=[2013, 2018, 2023])

fig.tight_layout()
fig.savefig(snakemake.output.highlighted_clusters_date_kdes)