In [None]:
import os

import pandas as pd
import numpy as np
from glob import glob

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

from adjustText import adjust_text

import scanpy as sc

# --- Project configuration -------------------------------------------------
# Prefix of every checkpoint that is read and of every output file written.
PROJECT_NAME = "CropSeq-23-2"

# Sub-projects: one preprocessed checkpoint is loaded per entry, and the same
# names are used as batch labels when the datasets are concatenated.
PROJECTS = ["Th17-1", "Th17-2", "Th1"]

# --- Directory layout ------------------------------------------------------
# Every path is derived from the current working directory.
BASE_DIR = os.getcwd()
NOTEBOOK_DIR = os.path.join(BASE_DIR, "notebooks")

DATA_DIR = os.path.join(BASE_DIR, "data")
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")

# Generated artefacts (figures and tables) are written below "processed/pdf".
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
PDF_DIR = os.path.join(PROCESSED_DIR, "pdf")

In [None]:
def sfile(filename):
    """Build the output path for a generated file and announce it.

    The file is placed in ``PDF_DIR`` and its name is prefixed with
    ``"<PROJECT_NAME>_merged_"``. The output directory is created on demand so
    that the first ``savefig``/``to_csv`` call of a fresh checkout does not
    fail because the folder is missing.

    Parameters
    ----------
    filename : str
        Bare file name including extension, e.g. ``"umap.pdf"``.

    Returns
    -------
    str
        Full path of the file to write.
    """
    # Writers such as fig.savefig do not create missing parent directories.
    os.makedirs(PDF_DIR, exist_ok=True)
    _fname = os.path.join(PDF_DIR, f"{PROJECT_NAME}_merged_{filename}")
    print(f"File save at '{_fname}'")
    return _fname

In [None]:
# Checkpoint handling functions

def save_checkpoint(adata_obj, filename, overwrite=False):
    """Write ``adata_obj`` as an h5ad file named ``filename`` into CHECKPOINT_DIR.

    Raises
    ------
    FileExistsError
        If the target file already exists and ``overwrite`` is false.
    """
    target_path = os.path.join(CHECKPOINT_DIR, filename)
    already_present = os.path.isfile(target_path)
    if already_present and not overwrite:
        raise FileExistsError(f"File '{target_path}' already exists")
    adata_obj.write_h5ad(target_path)

def load_checkpoint(filename):
    """Read the h5ad checkpoint ``filename`` from CHECKPOINT_DIR and return it.

    Raises
    ------
    FileNotFoundError
        If no such file exists in the checkpoint directory.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.isfile(checkpoint_path):
        return sc.read_h5ad(checkpoint_path)
    raise FileNotFoundError(f"Cant find file '{checkpoint_path}'")

def list_checkpoints():
    """Return the names of all entries in CHECKPOINT_DIR and print how many were found."""
    pattern = os.path.join(CHECKPOINT_DIR, "*")
    # Strip the directory part so only the bare entry names are returned.
    found_checkpoints = [os.path.basename(path) for path in glob(pattern)]
    print(f"Found {len(found_checkpoints)} checkpoint files in dir '{CHECKPOINT_DIR}'")
    return found_checkpoints

In [None]:
# Load the preprocessed checkpoint of every sub-project, keyed by project name.
adata_objs = {
    patient: load_checkpoint(f"{PROJECT_NAME}-{patient}-preprocessed.h5ad")
    for patient in PROJECTS
}

In [None]:
# Merge the per-project AnnData objects into one; the batch categories are the
# project names (the dict was filled in PROJECTS order, so the labels line up).
per_project_adatas = list(adata_objs.values())
adata_concat = per_project_adatas[0].concatenate(
    per_project_adatas[1:],
    batch_categories=PROJECTS,
)

# Drop every reference to the individual datasets so their memory can be freed.
del adata_objs, per_project_adatas

In [None]:
# Rerun concat dataset PCA and neighbor analysis
# These calls work on the merged object in place and run before the harmony
# step below, so the resulting embedding still contains any batch effect.
# PCA with 40 components using the ARPACK solver.
sc.pp.pca(adata_concat, svd_solver='arpack', n_comps=40)
# Neighbourhood graph from 40 neighbours on the 40 principal components.
sc.pp.neighbors(adata_concat, n_neighbors=40, n_pcs=40)
# Leiden clustering (resolution 0.5) and UMAP embedding on that graph.
sc.tl.leiden(adata_concat, resolution = 0.5)
sc.tl.umap(adata_concat)

In [None]:
# Show the batch effect before integration: one UMAP panel per obs column,
# mapping the column to colour by onto the panel title.
batch_panels = {
    "batch": "Batches",
    "leiden": "UMAP with leiden clustering",
}

plot = sc.pl.umap(
    adata_concat,
    color=list(batch_panels.keys()),
    title=list(batch_panels.values()),
    frameon=False,
    show=False,
)

In [None]:
# Run harmony
# Batch correction keyed on the obs column 'batch'. The corrected embedding is
# requested under obsm['X_pca_harmony']; the assert below checks that it was
# actually written before any later cell relies on it.
import scanpy.external as sce
sce.pp.harmony_integrate(adata_concat, 'batch', adjusted_basis = "X_pca_harmony")
assert 'X_pca_harmony' in adata_concat.obsm

In [None]:
# Rebuild the neighbourhood graph on the harmony-corrected embedding, then
# recompute UMAP and leiden clustering on it.
# NOTE(review): n_neighbors is not passed here, whereas the pre-integration
# graph used n_neighbors=40 - confirm the difference is intended.
sc.pp.neighbors(adata_concat, use_rep = "X_pca_harmony")
sc.tl.umap(adata_concat)
# The clustering is stored under obs['leiden'] (key_added), presumably
# replacing the pre-integration clustering of the same name.
sc.tl.leiden(adata_concat, resolution = 0.5, key_added="leiden")

# Show batch effect (same two panels as before integration, for comparison)
plot = sc.pl.umap(
    adata_concat,
    color=["batch", "leiden"],
    show = False,
    frameon = False,
    title=["Batches", "UMAP with leiden clustering"]
)

In [None]:
# Flag highly variable genes and, because subset=True, reduce adata_concat to
# those genes in place. This runs after PCA, harmony, UMAP and clustering were
# computed, so those results are based on the gene set from before this call.
# NOTE(review): the distance table further down reads adata_concat.raw -
# presumably .raw keeps the unsubsetted genes; confirm.
sc.pp.highly_variable_genes(adata_concat, subset=True)

In [None]:
# Label every cell as "control" or "not control" based on its guide target.
control_labels = np.where(
    adata_concat.obs["guide_target"] == "control",
    "control",
    "not control",
)
adata_concat.obs["is_control"] = control_labels

# Control cells are drawn larger (30) than all other cells (10).
control_point_sizes = np.where(control_labels == "control", 30, 10).tolist()

fig, ax = plt.subplots(figsize=(8, 8))

ax = sc.pl.umap(
    adata_concat,
    color=["is_control"],
    palette={"control": "#e52521", "not control": "lightgrey"},
    size=control_point_sizes,
    ax=ax,
    show=False,
)

ax.set_title("UMAP with highlighted control cells")

# One major tick every 3 units on both axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(MultipleLocator(3))

# Hide the top and right spine.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)

fig.savefig(sfile("umap-highlight-control-scatter.pdf"), transparent=True)

In [None]:
# Label every cell as "no guide" or "guide" based on its guide target.
no_guide_labels = np.where(
    adata_concat.obs["guide_target"] == "no_guide",
    "no guide",
    "guide",
)
adata_concat.obs["is_no_guide"] = no_guide_labels

# Cells without a guide are drawn larger (30) than all other cells (10).
no_guide_point_sizes = np.where(no_guide_labels == "no guide", 30, 10).tolist()

fig, ax = plt.subplots(figsize=(8, 8))

ax = sc.pl.umap(
    adata_concat,
    color=["is_no_guide"],
    palette={"no guide": "#0d9488", "guide": "lightgrey"},
    size=no_guide_point_sizes,
    ax=ax,
    show=False,
)

ax.set_title("UMAP with highlighted no guide cells")

# One major tick every 3 units on both axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(MultipleLocator(3))

# Hide the top and right spine.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)

fig.savefig(sfile("umap-highlight-no-guide-scatter.pdf"), transparent=True)

In [None]:
# UMAP coloured by the leiden clustering, drawn on a dedicated 8x8 inch figure.
fig, ax = plt.subplots(figsize=(8, 8))

ax = sc.pl.umap(adata_concat, color=["leiden"], ax=ax, show=False)
ax.set_title("UMAP with leiden clustering")

# One major tick every 3 units on both axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(MultipleLocator(3))

# Hide the top and right spine.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)

fig.savefig(sfile("umap-clustering-scatter.pdf"), transparent=True)

In [None]:
# Colour group of every guide target; targets in the same group share a colour.
# A guide target missing from this mapping raises a KeyError in the loop below.
target_groups = {
    "Tgfbr1": "group_4",
    "Tbx21": "group_1",
    "Cd40lg": "group_5",
    "Il17a": "group_2",
    "Tgfbr2": "group_4",
    "Il6ra": "group_4",
    "control": "group_3",
    "Il23r": "group_4",
    "Jak2": "group_6",
    "Myd88": "group_6",
    "Cebpb": "group_1",
    "Stat3": "group_1",
    "Syk": "group_6",
    "Runx1": "group_1",
    "Ccr6": "group_4",
    "Traf6": "group_6",
    "Il1r1": "group_4",
    "Stat6": "group_1",
    "Tlr4": "group_6",
    "Nfatc2": "group_1",
    "Ahr": "group_1",
    "Socs3": "group_6",
    "Socs1": "group_6",
    "Il12rb": "group_4",
    "Rorc": "group_1",
    "Nfkb1": "group_1",
    "Il7r": "group_4",
    "Irf4": "group_1",
    "Jak1": "group_6",
    "no_guide": "group_7",
}

group_colors = {
    "group_1": "#89cfbb",
    "group_2": "#c59ac4",
    "group_3": "#e52521",
    "group_4": "#b4c79c",
    "group_5": "#a99bc8",
    "group_6": "#e4c982",
    "group_7": "#0d9488",
}

# One summary point per guide target: median UMAP position and number of cells.
# The UMAP coordinates are indexed directly with a boolean mask instead of
# slicing the whole AnnData object by obs names for every guide.
umap_coords = adata_concat.obsm["X_umap"]
guide_targets = adata_concat.obs["guide_target"]

records = []
for guide in guide_targets.unique():
    guide_mask = (guide_targets == guide).to_numpy()
    umap_coord = np.median(umap_coords[guide_mask], axis=0)
    records.append({
        "target": guide,
        "umap_x": umap_coord[0],
        "umap_y": umap_coord[1],
        "count": int(guide_mask.sum()),
        "color": target_groups[guide],
    })

# Built once from the collected records; `data` is only ever a DataFrame.
data = pd.DataFrame(records, columns=["target", "umap_x", "umap_y", "count", "color"])

fig, ax = plt.subplots(figsize=(8, 8))
ax = sns.scatterplot(
    data=data,
    x="umap_x",
    y="umap_y",
    size="count",
    hue="color",
    sizes=(1, 350),
    palette=group_colors,
    ax=ax,
)

ax.set_title("UMAP Medians per gRNA group")

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Label every point with its guide target and let adjust_text de-overlap the
# labels. The keyword must be spelled `arrowprops`; the former `rrowprops`
# was not recognised, so no connector lines were drawn.
all_labels = [ax.text(point.umap_x, point.umap_y, point.target) for point in data.itertuples(index=False)]
adjust_text(all_labels, arrowprops=dict(arrowstyle='-', color='black'))

sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))


fig.savefig(sfile("umap-median-scatter.pdf"), transparent=True)

In [None]:
# UMAP coloured by the no-guide subcluster annotation. Cells that carry a
# subcluster label are drawn larger (30) than unassigned cells (10).
no_guide_cluster_labels = adata_concat.obs["no_guide_cluster"]

# Unassigned cells may be stored either as a missing value or as the literal
# string "NA" (the distance table cell filters this column with .isna()).
# Treat both as unassigned; comparing only against "NA" would give every cell
# the large size whenever the column holds real missing values.
is_unassigned = (
    no_guide_cluster_labels.isna()
    | (no_guide_cluster_labels.astype(object) == "NA")
).to_numpy()

fig, ax = plt.subplots(figsize = (8, 8))

ax = sc.pl.umap(
    adata_concat,
    color=["no_guide_cluster"],
    ax=ax,
    size=np.where(is_unassigned, 10, 30),
    show=False,
    alpha=0.7,
)

# The previous title ("UMAP with leiden clustering") was left over from the
# clustering plot; this figure shows the no-guide subclusters.
ax.set_title("UMAP with no guide clusters")

# Set ticks
ax.xaxis.set_major_locator(MultipleLocator(3))
ax.yaxis.set_major_locator(MultipleLocator(3))

# Remove top and right spine
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

fig.savefig(sfile("umap-no-guide.pdf"), transparent=True)

In [None]:
# Colour group of every plotted label; labels in the same group share a colour.
# The "No guide ..." entries cover the no-guide cells split by subcluster.
target_groups = {
    "Tgfbr1": "group_4",
    "Tbx21": "group_1",
    "Cd40lg": "group_5",
    "Il17a": "group_2",
    "Tgfbr2": "group_4",
    "Il6ra": "group_4",
    "control": "group_3",
    "Il23r": "group_4",
    "Jak2": "group_6",
    "Myd88": "group_6",
    "Cebpb": "group_1",
    "Stat3": "group_1",
    "Syk": "group_6",
    "Runx1": "group_1",
    "Ccr6": "group_4",
    "Traf6": "group_6",
    "Il1r1": "group_4",
    "Stat6": "group_1",
    "Tlr4": "group_6",
    "Nfatc2": "group_1",
    "Ahr": "group_1",
    "Socs3": "group_6",
    "Socs1": "group_6",
    "Il12rb": "group_4",
    "Rorc": "group_1",
    "Nfkb1": "group_1",
    "Il7r": "group_4",
    "Irf4": "group_1",
    "Jak1": "group_6",
    "No guide Th1": "group_7",
    "No guide Th17": "group_7",
    "No guide Treg": "group_7",
    "No guide": "group_7",
}

group_colors = {
    "group_1": "#89cfbb",
    "group_2": "#c59ac4",
    "group_3": "#e52521",
    "group_4": "#b4c79c",
    "group_5": "#a99bc8",
    "group_6": "#e4c982",
    "group_7": "#0d9488",
}

# Function used to summarise the UMAP coordinates of a group of cells.
SUM_FUNC = np.median

# Subclusters that get their own point; all other no-guide cells are pooled.
no_guide_subclusters = ["Th1", "Th17", "Treg"]

umap_coords = adata_concat.obsm["X_umap"]
guide_targets = adata_concat.obs["guide_target"]
no_guide_clusters = adata_concat.obs["no_guide_cluster"]


def _summary_point(label, cell_mask):
    """Return the plot record for the cells selected by ``cell_mask``.

    Returns ``None`` when the mask selects no cells: the summary of an empty
    group would be NaN and yield a label without a usable position.
    """
    cell_mask = np.asarray(cell_mask, dtype=bool)
    n_cells = int(cell_mask.sum())
    if n_cells == 0:
        return None
    umap_coord = SUM_FUNC(umap_coords[cell_mask], axis=0)
    return {
        "target": label,
        "umap_x": umap_coord[0],
        "umap_y": umap_coord[1],
        "count": n_cells,
        "color": target_groups[label],
    }


records = []
for guide in guide_targets.unique():
    is_guide = guide_targets == guide

    if guide == "no_guide":
        # One point per named subcluster ...
        groups = [
            (f"No guide {item}", is_guide & (no_guide_clusters == item))
            for item in no_guide_subclusters
        ]
        # ... plus one point for the rest of the no-guide cells.
        groups.append(("No guide", is_guide & ~no_guide_clusters.isin(no_guide_subclusters)))
    else:
        groups = [(guide, is_guide)]

    # Each mask is built once and reused for both the position and the count.
    for label, cell_mask in groups:
        record = _summary_point(label, cell_mask)
        if record is not None:
            records.append(record)


# Built once from the collected records; `data` is only ever a DataFrame.
data = pd.DataFrame(records, columns=["target", "umap_x", "umap_y", "count", "color"])

fig, ax = plt.subplots(figsize = (8, 8))
ax = sns.scatterplot(
    data=data,
    x="umap_x",
    y="umap_y",
    size="count",
    hue="color",
    sizes=(1, 350),
    palette=group_colors,
    ax=ax,
    alpha=0.8,
)

ax.set_title("UMAP Medians per gRNA group")

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Label every point and let adjust_text de-overlap the labels, connecting
# moved labels to their point with thin grey lines.
all_labels = [ax.text(point.umap_x, point.umap_y, point.target) for point in data.itertuples(index=False)]
adjust_text(all_labels, arrowprops=dict(arrowstyle='-', color='gray', alpha=0.8, lw=0.5))

sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))


fig.savefig(sfile("umap-median-scatter-no-guide-subcluster.pdf"), transparent=True)

In [None]:
# Names of all no-guide subclusters present in the data (missing values dropped).
all_cluster_names = list(adata_concat.obs["no_guide_cluster"].dropna().unique())

row_list = []


def _mean_expression(cell_mask):
    """Return the mean ``.raw`` expression profile of the selected cells as a 1-D array.

    ``cell_mask`` is a boolean mask over the cells of ``adata_concat``. An
    empty selection yields an all-NaN profile instead of failing on a mean
    over zero rows, so distances to an empty group come out as NaN.
    """
    cell_mask = np.asarray(cell_mask, dtype=bool)
    if not cell_mask.any():
        return np.full(adata_concat.raw.X.shape[1], np.nan)
    profile = np.mean(adata_concat[cell_mask, :].raw.X, axis=0)
    # The mean may come back as a 1 x n_genes matrix; flatten it.
    return np.asarray(profile).ravel()


def _euclidean_distance(profile_a, profile_b):
    """Return the Euclidean distance between two expression profiles."""
    return np.sqrt(np.sum(np.power(profile_a - profile_b, 2)))


def _guide_mask(target):
    """Return the boolean mask of cells whose guide target equals ``target``."""
    return adata_concat.obs["guide_target"] == target


def _no_guide_subset_mask(subset):
    """Return the mask of no-guide cells that belong to subcluster ``subset``."""
    return _guide_mask("no_guide") & (adata_concat.obs["no_guide_cluster"] == subset)


def get_distance_control(target):
    """Distance between the mean raw profiles of ``target`` cells and control cells."""
    return _euclidean_distance(
        _mean_expression(_guide_mask(target)),
        _mean_expression(_guide_mask("control")),
    )


def get_distance_no_guide(target):
    """Distance between the mean raw profiles of ``target`` cells and all no-guide cells."""
    return _euclidean_distance(
        _mean_expression(_guide_mask(target)),
        _mean_expression(_guide_mask("no_guide")),
    )


def get_distance_subset(target, subset):
    """Distance between the mean raw profiles of ``target`` cells and the
    no-guide cells of subcluster ``subset``."""
    return _euclidean_distance(
        _mean_expression(_guide_mask(target)),
        _mean_expression(_no_guide_subset_mask(subset)),
    )


# The reference profiles do not depend on the target, so they are computed
# once here instead of once per target. Insertion order fixes the column
# order of the table: subclusters first, then "Control" and "ControlNoGuide".
reference_profiles = {
    cluster: _mean_expression(_no_guide_subset_mask(cluster))
    for cluster in all_cluster_names
}
reference_profiles["Control"] = _mean_expression(_guide_mask("control"))
reference_profiles["ControlNoGuide"] = _mean_expression(_guide_mask("no_guide"))

for target in adata_concat.obs["guide_target"].unique():
    # The target profile is shared by all distances of this row.
    target_profile = _mean_expression(_guide_mask(target))
    row_list.append({
        "group": target,
        **{
            name: _euclidean_distance(target_profile, reference_profile)
            for name, reference_profile in reference_profiles.items()
        },
    })

In [None]:
# Turn the collected distance rows into a table and write it out as CSV.
distance_table = pd.DataFrame(row_list)
distance_table.to_csv(sfile("euclidian-distances-allSubclusters.csv"))