In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

os.chdir("../../../../")

import yaml
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
import mudata as md
import muon as mu
import numpy as np
import scanpy as sc
from matplotlib import gridspec
import scanpy.external as sce
from scipy import sparse
import celltypist
from matplotlib.colors import LinearSegmentedColormap

from utils.utils import *
from utils.plotting import *

import matplotlib as mpl

sc.settings.verbosity = 0
# Configure scanpy figures with a single call: previously a second bare
# `sc.set_figure_params(dpi=100)` followed, which re-applied scanpy's default
# `frameon=True` and silently undid the `frameon=False` requested here.
sc.settings.set_figure_params(
    dpi=100,
    facecolor="white",
    frameon=False,
)

sns.set(style="white")

## Load data

In [None]:
# Cleaned CD8+ EM/RM subset of the ustekinumab ANCA dataset (RNA + CITE modalities).
clean_h5mu_path = "data/anca/ustekinumab/R_ANCA_4PK4PB_ustekinumab_cd8emrm_clean.h5mu"
mudata = md.read_h5mu(clean_h5mu_path)
mudata

## Prepare data

In [None]:
# Handles to the RNA and CITE modality AnnData objects held by the MuData container.
mod_rna, mod_cite = mudata.mod["rna"], mudata.mod["cite"]

In [None]:
# UMAP coloured by the resolution-0.75 SNN clustering used for annotation below.
sc.pl.umap(
    adata=mod_rna,
    color="RNA_snn_res.0.75",
)

## Visualizations

In [None]:
# Genes used for the TRM ("trm_up") signature score.
trm_up = [
    "CD69",
    "CA10",
    "ITGA1",
    "ITGAE",
    "IL2",
    "IL10",
    "CXCR6",
    "CXCL13",
    "KCNK5",
    "RGS1",
    "CRTAM",
    "DUSP6",
    "PDCD1",
    "IL23R",
]
# The score is computed on scaled expression. Keep the current matrix (presumably
# log1p-normalized, per the layer name) and always put it back afterwards, so X is
# never left scaled for later cells, even if scoring raises.
mod_rna.layers["log1p"] = mod_rna.X.copy()
try:
    sc.pp.scale(mod_rna)
    sc.tl.score_genes(
        mod_rna,
        gene_list=trm_up,
        score_name="trm_up_score",
        ctrl_size=50,
        use_raw=False,
    )
finally:
    mod_rna.X = mod_rna.layers["log1p"].copy()

In [None]:
# Symmetric diverging colour scale around 0 for the TRM signature score.
trm_umap_kwargs = dict(
    color=["trm_up_score"],
    color_map="RdBu_r",
    vmin=-0.4,
    vmax=0.4,
    size=30,
    title="TRM score",
    ncols=1,
)
sc.pl.umap(mod_rna, **trm_umap_kwargs)

In [None]:
marker_df = pd.read_excel("data/markers/T cell markers.xlsx")
marker_df = marker_df.set_index("Cell type")


def _split_markers(cell):
    """Parse one comma-separated marker cell into gene symbols.

    Missing (NaN) cells yield []. Spaces are stripped, and empty entries (e.g. from
    a trailing comma) are dropped.
    """
    # pd.isna instead of `is not np.nan`: an identity test misses NaN objects that
    # are not the np.nan singleton (e.g. from an all-empty float column), which
    # would then crash on `.replace`.
    if pd.isna(cell):
        return []
    return [gene for gene in str(cell).replace(" ", "").split(",") if gene]


# Extract positive markers: the union of both columns, de-duplicated in first-seen
# order. list(set(...)) would reorder markers on every kernel start because string
# hashing is randomized per process.
marker_db = {}
for cell_type, row in marker_df.iterrows():
    markers = _split_markers(row["positive markers 1"]) + _split_markers(
        row["positive markers 2"]
    )
    marker_db[cell_type] = list(dict.fromkeys(markers))

marker_plotter = MarkerPlotter(mod_rna, mod_cite, marker_db, cluster_key=None)

In [None]:
# NK-receptor / NK-associated RNA panel.
nk_marker_panel = [
    "NKG7",
    "KLRC1",
    "KLRD1",
    "KLRF1",
    "KLRB1",
    "NCR1",
    "NCAM1",
    "FGFBP2",
    "XCL1",
    "XCL2",
]
fig = marker_plotter.visualize_markers(
    markers=nk_marker_panel,
    dtype="rna",
    ncols=3,
    use_default_plot=True,
    dotplot=False,
)

In [None]:
# T-cell lineage genes (CD3 subunits, CD4, CD8A).
lineage_panel = ["CD4", "CD3D", "CD3E", "CD8A"]
fig = marker_plotter.visualize_markers(
    markers=lineage_panel,
    dtype="rna",
    ncols=3,
    use_default_plot=True,
    dotplot=False,
)

In [None]:
# One RNA marker figure per cell type in the marker database.
rna_panel_options = {
    "dtype": "rna",
    "ncols": 3,
    "use_default_plot": True,
    "dotplot": False,
}
for ct_name in marker_db:
    fig = marker_plotter.visualize_markers(cell_type=ct_name, **rna_panel_options)

In [None]:
# Same per-cell-type panels, drawn from the protein (CITE) modality.
protein_panel_options = {
    "dtype": "protein",
    "ncols": 3,
    "use_default_plot": True,
    "dotplot": False,
}
for ct_name in marker_db:
    fig = marker_plotter.visualize_markers(cell_type=ct_name, **protein_panel_options)

## Annotations

In [None]:
# obs column with the clustering (resolution 0.75) that the manual annotations refer to.
cluster_key = "RNA_snn_res.0.75"

In [None]:
# Project helper; presumably lists the cluster ids in `cluster_key` so the
# `annotations` mapping below can be checked against them — TODO confirm.
get_cluster_names(mod_rna, cluster_key)

In [None]:
# Manual cluster-id -> cell-type labels for `cluster_key`; clusters 1-3 share a label.
annotations = {"0": "CD8+ Tcm"}
annotations.update(dict.fromkeys(("1", "2", "3"), "Tc1-like"))
annotations.update({"4": "CD8+ Tem/naive", "5": "NKT", "6": "Tc1"})

In [None]:
# apply annotations
mod_rna.obs["cell_type_fine"] = mod_rna.obs[cluster_key].replace(annotations)

## Final visualizations

In [None]:
# obs column holding the manual labels assigned from `annotations`; used as the
# groupby / colour key in the plots below.
cell_type_key = "cell_type_fine"

In [None]:
# Display order of the annotated cell types. It is used later for dotplot groups and
# the donut plot, so it must be defined: previously it existed only as commented-out
# code with labels that no longer match `annotations`. It is derived from
# `annotations` so the two cannot drift; reorder by hand for a different layout.
observed_labels = set(mod_rna.obs[cell_type_key].astype(str))
cell_type_order = [
    label for label in dict.fromkeys(annotations.values()) if label in observed_labels
]
unordered = sorted(observed_labels - set(cell_type_order))
assert not unordered, f"labels without a position in cell_type_order: {unordered}"
mod_rna.obs[cell_type_key] = pd.Categorical(
    mod_rna.obs[cell_type_key], categories=cell_type_order, ordered=True
)

In [None]:
# Curated RNA marker panel, grouped by the program/population each gene reports on.
# A plain dict already keeps insertion order (Python 3.7+), so the former
# collections.OrderedDict conversion is unnecessary. `collections` is not imported
# in this notebook either; it only resolved if a star import happened to leak it.
selected_rna_markers = {
    "general": ["CD3E", "CD4", "CD8A"],
    "Tc1": ["IFNG", "TNF", "CXCR3"],
    "Tc17": ["RORC", "CCR6"],
    "gdT": ["TRDV2", "TRGV9"],
    "NKT": ["KLRB1", "NCAM1", "XCL1"],
    "cytotoxic": ["GZMB", "PRF1"],
    "CM/naive": ["CCR7", "SELL", "KLF2", "S1PR1"],
    "prolif.": ["STMN1"],
}

sc.pl.dotplot(
    mod_rna,
    var_names=selected_rna_markers,
    groupby=cell_type_key,
    standard_scale="var",
    var_group_rotation=25,
)

In [None]:
# Carry the RNA-derived labels over to the protein modality. The assignment aligns
# on obs_names, so first check that every CITE cell exists in the RNA modality;
# otherwise unmatched cells would silently receive NaN labels.
assert mod_cite.obs_names.isin(mod_rna.obs_names).all(), (
    "CITE cells missing from the RNA modality"
)
mod_cite.obs[cell_type_key] = mod_rna.obs[cell_type_key]
selected_protein_markers = [
    "CD3",
    "CD8",
    "CD4",
    "CD69",
    "CD27",
    "CCR7",
    "CD45RO",
    "CD45RA",
    "CCR6",
    "CXCR3",
]

# var_names is a flat list (no groups), so var_group_rotation would have no effect.
sc.pl.dotplot(
    mod_cite,
    var_names=selected_protein_markers,
    groupby=cell_type_key,
    standard_scale="var",
)

In [None]:
# UMAP of the final annotated subsets.
plot_umap(
    mod_rna,
    color=cell_type_key,
    title="CD8+ EM/RM subsets",
    figsize=(5, 5),
)


In [None]:
# TRM score again, next to the final annotation; colour limits symmetric about 0.
score_limit = 0.4
sc.pl.umap(
    mod_rna,
    color=["trm_up_score"],
    title="TRM score",
    color_map="RdBu_r",
    vmin=-score_limit,
    vmax=score_limit,
    size=30,
    ncols=1,
)

In [None]:
# Differential expression between annotated cell types on the RNA modality.
rna_de_params = dict(mod="rna", top_n=5, min_expression=0.1)
(
    all_marker_results,
    filtered_marker_results,
    best_markers,
) = run_de_pipeline(mod_rna, cell_type_key, **rna_de_params)


In [None]:
top_markers = get_top_n_markers(best_markers, n=10)
# Order marker groups by the label column's categories (the order dotplot uses for
# its rows) instead of a separately maintained list. Skip cell types that yielded
# no markers rather than raising KeyError.
group_order = pd.Categorical(mod_rna.obs[cell_type_key]).categories
top_markers = {k: top_markers[k] for k in group_order if k in top_markers}
sc.pl.dotplot(
    mod_rna,
    var_names=top_markers,
    groupby=cell_type_key,
    standard_scale="var",
    var_group_rotation=25,
)


In [None]:
# Remove hashtag features (names starting with "Hash", e.g. HashB/HashK; presumably
# sample-multiplexing antibodies rather than cell-type markers).
protein_names = [name for name in mod_cite.var_names if not name.startswith("Hash")]
mod_cite = mod_cite[:, protein_names].copy()

mod_cite.obs[cell_type_key] = mod_rna.obs[cell_type_key].copy()
all_marker_results, filtered_marker_results, best_markers = run_de_pipeline(
    mod_cite,
    cell_type_key,
    mod="cite",
    top_n=5,
    min_expression=0.1,
)

top_markers = get_top_n_markers(best_markers, n=10)
# Group order follows the label categories (the dotplot row order). Cell types
# without markers are skipped instead of raising KeyError.
group_order = pd.Categorical(mod_cite.obs[cell_type_key]).categories
top_markers = {k: top_markers[k] for k in group_order if k in top_markers}
sc.pl.dotplot(
    mod_cite,
    var_names=top_markers,
    groupby=cell_type_key,
    standard_scale="var",
    var_group_rotation=25,
)

In [None]:
# Pair each label with the colour scanpy assigned to it. scanpy stores
# `<key>_colors` in the order of the column's categories, so zip against those
# categories rather than indexing by position in an independent ordering list,
# which may list labels in a different order.
label_categories = mod_rna.obs["cell_type_fine"].cat.categories
label_palette = mod_rna.uns["cell_type_fine_colors"]
assert len(label_palette) == len(label_categories), (
    "colour palette does not match the label categories; re-plot to refresh it"
)
color_map = dict(zip(label_categories, label_palette))
donut_plot(
    adata=mod_rna,
    cell_type_col="cell_type_fine",
    color_map=color_map,
    label_order=list(label_categories),
)


In [None]:
# Functional / homing RNA markers, each panel on its own (uncapped) colour scale.
rna_feature_panel = [
    "IFNG",
    "TNF",
    "RORC",
    "CCR6",
    "TRDV2",
    "TRGV9",
    "KLRB1",
    "NCAM1",
    "XCL1",
    "GZMB",
    "PRF1",
    "CCR7",
    "SELL",
    "KLF2",
    "S1PR1",
]
fig = marker_plotter.visualize_markers(
    markers=rna_feature_panel,
    dtype="rna",
    ncols=4,
    use_default_plot=True,
    dotplot=False,
    vmax=None,
)


In [None]:
# Chemokine receptor / CD45 isoform proteins.
protein_panel = ["CCR6", "CCR7", "CXCR3", "CCR4", "CD45RA", "CD45RO"]
fig = marker_plotter.visualize_markers(
    markers=protein_panel,
    dtype="protein",
    ncols=3,
    use_default_plot=True,
    dotplot=False,
    # one "p99" cap per panel (presumably the 99th-percentile vmax convention of scanpy)
    vmax=["p99"] * len(protein_panel),
)


## Save annotations

In [None]:
# Build a fresh MuData from the current modality objects (mod_cite was subset to
# non-hashtag features above); the previous container is discarded.
mudata = md.MuData({"rna": mod_rna, "cite": mod_cite})

In [None]:
# Number of cells per annotated cell type.
mod_rna.obs["cell_type_fine"].value_counts()

In [None]:
# Inspect the global obs table of the rebuilt MuData before saving.
mudata.obs

In [None]:
# Persist the annotated object next to the cleaned input file.
annotated_dir = "data/anca/ustekinumab"
save_path = f"{annotated_dir}/R_ANCA_4PK4PB_ustekinumab_cd8emrm_annotated.h5mu"
mudata.write_h5mu(save_path)

## Prepare for celltypist

In [None]:
# Inspect the RNA AnnData summary before running CellTypist.
mod_rna


In [None]:
# Project helper; judging by the next cell it adds `celltypist_cell_label_coarse` /
# `celltypist_cell_label_fine` to mod_rna.obs — confirm.
# NOTE(review): this runs after the .h5mu was written above, so these labels are not
# included in the saved file.
run_celltypist(mod_rna)


In [None]:
# Compare CellTypist's coarse and fine predictions on the UMAP.
celltypist_label_keys = ["celltypist_cell_label_coarse", "celltypist_cell_label_fine"]
sc.pl.umap(mod_rna, color=celltypist_label_keys)