In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
# Move to the project root (four levels above the notebook) so that the relative
# "data/..." paths and the `utils` package resolve. The start directory is remembered
# on first execution so that re-running this cell does not climb another four levels.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, "../../../../"))

import anndata as ad
import matplotlib.pyplot as plt
import mudata as md
import muon as mu
import numpy as np
import scanpy as sc
from scipy import sparse

from utils.utils import *
from utils.plotting import *

# Keep scanpy's logging at its lowest verbosity level.
sc.settings.verbosity = 0

# Configure scanpy's figure defaults exactly once: each call to `set_figure_params`
# re-applies its defaults for every option that is not passed explicitly, so a later
# call with only `dpi` would silently switch `frameon` back on.
sc.settings.set_figure_params(
    dpi=100,
    facecolor="white",
    frameon=False,
)
import matplotlib as mpl

# seaborn's style is applied after scanpy's, so where both touch the same rcParams
# the seaborn value is the one in effect.
sns.set(style="white")

## Load data

In [None]:
# Read the annotated multimodal dataset and bind a short alias to each modality.
input_h5mu = "data/anca/exploratory/ANCA_27PK27PB_T_harmony_r_annotated.h5mu"
mudata = md.read_h5mu(input_h5mu)
mod_rna, mod_cite = mudata.mod["rna"], mudata.mod["cite"]
mudata  # last expression: display the MuData summary as the cell output

## Subset CD8+ EM/RM cells

In [None]:
celltype_key = "cell_type"

# Boolean mask over cells, built from the RNA modality's cell-type annotation.
is_cd8_emrm = mudata.mod["rna"].obs[celltype_key] == "CD8+ EM/RM"
mudata = mudata[is_cd8_emrm].copy()

# Re-bind the aliases so they refer to the modalities of the subset object.
mod_rna, mod_cite = mudata.mod["rna"], mudata.mod["cite"]

## Preprocess data

In [None]:
# Drop genes that are expressed in fewer than `min_cells_per_gene` cells and
# print the RNA matrix shape before and after to show how many genes were removed.
min_cells_per_gene = 10
print(mod_rna.shape)
sc.pp.filter_genes(mod_rna, min_cells=min_cells_per_gene)
print(mod_rna.shape)

In [None]:
# Discard every patient that contributes fewer than 2 cells.
cells_per_patient = mod_rna.obs["patient"].value_counts()
patients_to_remove = [
    patient for patient, n_cells in cells_per_patient.items() if n_cells < 2
]
in_removed_patient = mod_rna.obs["patient"].isin(patients_to_remove)
mod_rna = mod_rna[~in_removed_patient].copy()
# Keep the protein modality restricted to (and ordered like) the remaining RNA cells.
mod_cite = mod_cite[mod_rna.obs_names, :].copy()

### Quality control

In [None]:
# Quality-control plots for the RNA modality, passing the cell-type annotation column name.
# `plot_qc` is not defined in this notebook — presumably it comes from the star-imports
# of utils.utils / utils.plotting; check there for what exactly is plotted.
plot_qc(mod_rna, celltype_key)

In [None]:
# How many distinct patients remain after filtering (a missing value would count as one).
mod_rna.obs["patient"].nunique(dropna=False)

## Integrate data

In [None]:
# Normalize each modality before integration. Neither helper is defined in this
# notebook (presumably star-imported from utils.utils). Their return values are
# discarded, so they are assumed to modify the objects in place — TODO confirm.
# NOTE(review): totalVI is normally trained on raw counts; verify that these helpers
# keep the counts available (e.g. in a layer) for `integrate_with_totalvi`.
log_normalize(mod_rna)
protein_clr(mod_cite)

In [None]:
# Options for the totalVI integration; "patient" is used as the batch variable.
totalvi_options = dict(
    batch_key="patient",
    n_top_genes=4000,
    empirical_protein_background_prior=True,
)
model, rna_subset = integrate_with_totalvi(mod_rna, mod_cite, **totalvi_options)

In [None]:
# Training diagnostics: one curve per recorded history entry, labelled for the legend.
for history_key, curve_label in (("elbo_train", "train"), ("elbo_validation", "val")):
    plt.plot(model.history[history_key], label=curve_label)
plt.title("Negative ELBO over training epochs")
plt.legend()

In [None]:
# Store the model's latent representation on the RNA object, then build the
# neighbor graph and the UMAP embedding from that representation.
rep_name = "X_totalvi"
latent = model.get_latent_representation()
mod_rna.obsm[rep_name] = latent
sc.pp.neighbors(mod_rna, use_rep=rep_name)
sc.tl.umap(mod_rna)

In [None]:
# Colour the UMAP by each covariate to inspect how well they mix after integration.
umap_color_keys = ["sample", "patient", "tissue"]
sc.pl.umap(mod_rna, color=umap_color_keys, wspace=0.8)

In [None]:
# Sweep Leiden resolutions 0.4, 0.5, ..., 2.0. Values are rounded to one decimal so
# that the result columns get clean names such as "leiden_1.0".
resolutions = np.linspace(0.4, 2.0, 17)
for res in np.round(resolutions, 1):
    print(f"Running Leiden clustering with resolution {res}")
    sc.tl.leiden(mod_rna, key_added=f"leiden_{res}", resolution=res)

In [None]:
# Plot the Leiden clustering results stored on the RNA object. `plot_leiden_results`
# is not defined in this notebook (presumably star-imported from utils.plotting).
# NOTE(review): `rep_name` is passed as an empty string even though the variable
# `rep_name = "X_totalvi"` exists — confirm against the helper that "" is intended.
plot_leiden_results(mod_rna, rep_name="")

In [None]:
# Reassemble a MuData object from the processed modalities and save it to disk.
output_h5mu = "data/anca/exploratory/ANCA_27PK27PB_cd8emrm_totalvi.h5mu"
mudata = md.MuData({"rna": mod_rna, "cite": mod_cite})
mudata.write_h5mu(output_h5mu)

In [None]:
# Name of the `.obs` column holding the clustering used for all downstream
# marker analysis and plots (the Leiden run at resolution 1.0).
cluster_key = "leiden_1.0"

In [None]:
# Show the selected clustering on the UMAP, with the legend placed on the data.
sc.pl.umap(
    mod_rna,
    color=cluster_key,
    legend_loc="on data",
    wspace=0.8,
)

In [None]:
# Run the differential-expression helper on the RNA modality for the chosen
# clustering, then unpack its three results.
de_results = run_de_pipeline(
    mod_rna,
    cluster_key,
    mod="rna",
    top_n=5,
    min_expression=0.1,
)
all_marker_results, filtered_marker_results, best_markers = de_results

In [None]:
def _split_markers(cell) -> list:
    """Turn one spreadsheet cell of comma-separated marker names into a list.

    Missing cells (NaN/None) yield an empty list. All spaces are removed, and
    empty entries (e.g. from a trailing comma) are dropped.
    """
    # `pd.isna` is used instead of an identity check against `np.nan`: NaN values
    # handed out by pandas are not guaranteed to be the `np.nan` singleton.
    if pd.isna(cell):
        return []
    return [name for name in str(cell).replace(" ", "").split(",") if name]


marker_df = pd.read_excel("data/markers/T cell markers.xlsx")
marker_df = marker_df.set_index("Cell type")

# Map each cell type to its positive markers, merged from both marker columns.
# `dict.fromkeys` removes duplicates while keeping the spreadsheet order, so the
# marker order (and therefore the plot order) is the same on every run.
marker_db = {
    cell_type: list(
        dict.fromkeys(
            _split_markers(row["positive markers 1"])
            + _split_markers(row["positive markers 2"])
        )
    )
    for cell_type, row in marker_df.iterrows()
}

marker_plotter = MarkerPlotter(mod_rna, mod_cite, marker_db, cluster_key=cluster_key)

In [None]:
# RNA expression of a hand-picked set of genes, three panels per row.
selected_rna_markers = ["CD8A", "CD4", "CD3E"]
fig = marker_plotter.visualize_markers(
    markers=selected_rna_markers,
    dtype="rna",
    ncols=3,
    use_default_plot=True,
)

In [None]:
# RNA expression of a hand-picked set of interleukin genes, three panels per row.
interleukin_genes = ["IL5", "IL13", "IL9", "IL21", "IL22", "IL2"]
fig = marker_plotter.visualize_markers(
    markers=interleukin_genes,
    dtype="rna",
    ncols=3,
    use_default_plot=True,
)

In [None]:
# One RNA marker figure per cell type listed in the marker database.
rna_plot_options = dict(dtype="rna", ncols=3, use_default_plot=True)
for cell_type in marker_db:
    fig = marker_plotter.visualize_markers(cell_type=cell_type, **rna_plot_options)

## Protein markers

In [None]:
# Give the protein modality the same cluster labels and UMAP coordinates as the
# RNA modality so that both can be plotted on the same embedding.
rna_cluster_labels = mod_rna.obs[cluster_key]
rna_umap_coords = mod_rna.obsm["X_umap"]
mod_cite.obs[cluster_key] = rna_cluster_labels
mod_cite.obsm["X_umap"] = rna_umap_coords

In [None]:
# One protein marker figure per cell type listed in the marker database.
protein_plot_options = dict(dtype="protein", ncols=3, use_default_plot=True)
for cell_type in marker_db:
    fig = marker_plotter.visualize_markers(cell_type=cell_type, **protein_plot_options)