In [None]:
import os

import pandas as pd
import numpy as np
from glob import glob

import scanpy as sc

import matplotlib.pyplot as plt
import pandas as pd
from math import pi

import phate
import meld

# Project layout: every path is derived from the kernel's current working directory.
BASE_DIR = os.getcwd()
NOTEBOOK_DIR = os.path.join(BASE_DIR, "notebooks")

DATA_DIR = os.path.join(BASE_DIR, "data")
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
PDF_DIR = os.path.join(PROCESSED_DIR, "pdf")

# Prefix used in the names of exported files (see sfile).
PROJECT_NAME = "CropSeq-20-14"

def sfile(filename):
    """Return the export path for `filename` and make sure its directory exists.

    The file is placed in ``PDF_DIR`` and prefixed with ``<PROJECT_NAME>_merged_``.
    The resolved path is printed so the notebook output records where each
    artefact was written.

    Parameters
    ----------
    filename : str
        Base file name, e.g. ``"target-X-spiderchart.pdf"``.

    Returns
    -------
    str
        Full path inside ``PDF_DIR``.
    """
    _fname = os.path.join(PDF_DIR, f"{PROJECT_NAME}_merged_{filename}")
    # Callers pass this path straight to savefig/to_csv, which do not create
    # missing parent directories; create it here so a fresh checkout works.
    os.makedirs(PDF_DIR, exist_ok=True)
    print(f"File save at '{_fname}'")
    return _fname


# Checkpoint handling functions

def save_checkpoint(adata_obj, filename, overwrite=False):
    """Write `adata_obj` to ``CHECKPOINT_DIR/<filename>`` via ``write_h5ad``.

    Parameters
    ----------
    adata_obj : AnnData
        Object to save (anything exposing ``write_h5ad``).
    filename : str
        File name relative to ``CHECKPOINT_DIR``.
    overwrite : bool, default False
        Replace an existing file instead of refusing to write.

    Raises
    ------
    FileExistsError
        If the target file already exists and `overwrite` is False.
    """
    filename = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.isfile(filename) and not overwrite:
        raise FileExistsError(f"File '{filename}' already exists")
    # Nothing else in the notebook creates CHECKPOINT_DIR; ensure it exists
    # before handing the path to the writer.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    adata_obj.write_h5ad(filename)

def load_checkpoint(filename):
    """Read ``CHECKPOINT_DIR/<filename>`` with ``sc.read_h5ad`` and return the result.

    Raises
    ------
    FileNotFoundError
        If no regular file exists at that path.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.isfile(checkpoint_path):
        return sc.read_h5ad(checkpoint_path)
    raise FileNotFoundError(f"Cant find file '{checkpoint_path}'")

def list_checkpoints():
    """Return the base names of the entries matched by ``CHECKPOINT_DIR/*``.

    Also prints how many entries were found.
    """
    pattern = os.path.join(CHECKPOINT_DIR, "*")
    names = list(map(os.path.basename, glob(pattern)))
    print(f"Found {len(names)} checkpoint files in dir '{CHECKPOINT_DIR}'")
    return names

### Load checkpoint

In [None]:
# AnnData checkpoint this analysis starts from (file name indicates Seurat singlets,
# murine, revision 3).
SINGLETS_CHECKPOINT = "Cropseq_20_14_revision3_murine__seurat_Singlets.h5ad"
adata_concat = load_checkpoint(SINGLETS_CHECKPOINT)

In [None]:
# PCA on adata_concat; the result is read below as adata_concat.obsm["X_pca"],
# which is the input to both PHATE and MELD.
sc.tl.pca(adata_concat)

In [None]:
# PHATE embedding of the PCA representation. A fixed random_state makes the
# embedding (and every figure drawn on it) reproducible across Restart & Run All;
# without it each run produced a different layout.
PHATE_RANDOM_STATE = 42
phate_op = phate.PHATE(n_jobs=-1, random_state=PHATE_RANDOM_STATE)
data_phate = phate_op.fit_transform(adata_concat.obsm["X_pca"])

In [None]:
# Store the PHATE coordinates under "X_phate" so scanpy can plot them with basis="phate".
adata_concat.obsm["X_phate"] =  data_phate

In [None]:
# PHATE embedding coloured by Seurat cluster.
sc.pl.scatter(adata_concat, basis="phate", color="seurat_clusters", size=15)

In [None]:
# Run MELD on the PCA representation, using each cell's gRNA group as its sample
# label; the result has one density column per gRNA group.
meld_op = meld.MELD()
sample_densities = meld_op.fit_transform(
    adata_concat.obsm["X_pca"],
    adata_concat.obs["gRNA_group"],
)

In [None]:
# Prefix each gRNA-group column with "meld_density_". Columns that already carry
# the prefix are left untouched so re-running this cell does not produce
# "meld_density_meld_density_<group>".
_DENSITY_PREFIX = "meld_density_"
sample_densities = sample_densities.rename(
    columns=lambda col: col if str(col).startswith(_DENSITY_PREFIX) else f"{_DENSITY_PREFIX}{col}"
)

In [None]:
# Give the MELD output the cell barcodes as index so it can be joined onto
# adata_concat.obs. This is a positional assignment: it assumes MELD returned rows
# in the same order as adata_concat.obsm["X_pca"] / adata_concat.obs.
# NOTE(review): confirm the row order is preserved by meld.MELD.fit_transform.
sample_densities.index = adata_concat.obs.index

In [None]:
# Attach the MELD density columns to the cell metadata. Density columns from a
# previous run are dropped first, so re-running this cell replaces them instead of
# creating "_x"/"_y"-suffixed duplicates (which would break lookups such as
# "meld_density_control"). A left join keeps every cell, so obs can never shrink
# out of sync with the AnnData matrix.
adata_concat.obs = adata_concat.obs.drop(columns=sample_densities.columns, errors="ignore").join(
    sample_densities
)

In [None]:
# PHATE embedding coloured by the MELD density of the "control" gRNA group.
sc.pl.scatter(adata_concat, basis="phate", color="meld_density_control", size=15, color_map="coolwarm")

In [None]:
# Normalize the per-cell densities across gRNA groups to obtain sample likelihoods.
# The result is expected to keep sample_densities' columns ("meld_density_<group>"),
# which the next cell renames. Presumably each row sums to 1; see
# meld.utils.normalize_densities.
sample_likelihoods = meld.utils.normalize_densities(sample_densities)

In [None]:
def _likelihood_column(density_column):
    """Map "meld_density_<group>" to "meld_likelihood_<group>" (drops the first two "_" tokens)."""
    parts = density_column.split("_", 2)
    suffix = parts[2] if len(parts) > 2 else ""
    return f"meld_likelihood_{suffix}"


likelihood_renames = {col: _likelihood_column(col) for col in sample_densities.columns}
sample_likelihoods = sample_likelihoods.rename(columns=likelihood_renames)

In [None]:
# Attach the MELD likelihood columns to the cell metadata. Likelihood columns from a
# previous run are dropped first, so re-running this cell replaces them instead of
# creating "_x"/"_y"-suffixed duplicates. A left join keeps every cell, so obs
# cannot shrink out of sync with the AnnData matrix.
adata_concat.obs = adata_concat.obs.drop(columns=sample_likelihoods.columns, errors="ignore").join(
    sample_likelihoods
)

In [None]:
def meld_df_from_target(df, target):
    """Reshape the MELD columns of one gRNA target into a tidy frame.

    Reads ``meld_likelihood_<target>`` and ``meld_density_<target>`` from `df`.
    Returns a new DataFrame with a fresh 0..n-1 index and the columns
    ``target`` (constant), ``meld_likelihood`` and ``meld_density``.
    """
    likelihoods = df[f"meld_likelihood_{target}"].tolist()
    densities = df[f"meld_density_{target}"].tolist()
    labels = [target for _ in range(len(df))]
    columns = {
        "target": labels,
        "meld_likelihood": likelihoods,
        "meld_density": densities,
    }
    return pd.DataFrame(columns)


In [None]:
# Seurat cluster ids pooled into each T-helper group.
# NOTE(review): assumes `seurat_clusters` holds values comparable to these ints; if it
# is a categorical of strings ("0", "1", ...) `isin` matches nothing -- confirm.
_NO_GUIDE_CLUSTERS = {
    "Th17": [1, 5],
    "Th1": [3],
    "Treg": [0, 6],
}


def get_likelyhood_median(no_guide_cluster, target, obs=None, agg=np.mean):
    """Aggregate the MELD likelihood of `target` over control cells of a cluster group.

    Only cells whose ``gRNA_group`` is ``"control"`` and whose ``seurat_clusters``
    value belongs to the chosen group are used. Despite the name, the default
    aggregation is the *mean* (``np.mean``), which keeps existing results unchanged;
    pass ``agg=np.median`` for a true median.

    Parameters
    ----------
    no_guide_cluster : str
        One of ``"Th17"``, ``"Th1"``, ``"Treg"``; selects the clusters to pool.
    target : str
        gRNA group whose ``meld_likelihood_<target>`` column is aggregated.
    obs : pandas.DataFrame, optional
        Cell metadata table; defaults to the global ``adata_concat.obs``.
    agg : callable, optional
        Reduction applied to the selected likelihood values (default ``np.mean``).

    Returns
    -------
    float
        The aggregated likelihood (NaN when no cell matches, for the default
        reduction on an empty pandas Series).

    Raises
    ------
    ValueError
        If `no_guide_cluster` is not one of the known cluster groups.
    """
    try:
        clusters = _NO_GUIDE_CLUSTERS[no_guide_cluster]
    except KeyError:
        # An explicit error survives `python -O` (unlike assert) and says what was wrong.
        raise ValueError(
            f"Unknown cluster group {no_guide_cluster!r}; "
            f"expected one of {sorted(_NO_GUIDE_CLUSTERS)}"
        ) from None

    if obs is None:
        obs = adata_concat.obs

    mask = (obs["gRNA_group"] == "control") & obs["seurat_clusters"].isin(clusters)
    return agg(obs.loc[mask, f"meld_likelihood_{target}"])

In [None]:
# Radar ("spider") chart per guide target comparing its MELD likelihood with the
# control condition across T-helper cluster groups. Values come from
# get_likelyhood_median, i.e. they are aggregated over the control cells of each group.
SPIDER_CATEGORIES = ["Th17", "Th1", "Treg"]


def _close_polygon(values):
    """Return `values` as a list with its first element repeated at the end."""
    values = list(values)
    return values + values[:1]


def _plot_spider_chart(guide_target, target_values, control_values, categories):
    """Draw one radar chart of `target_values` vs `control_values` and return the figure."""
    n_categories = len(categories)
    angles = [n / float(n_categories) * 2 * pi for n in range(n_categories)]
    closed_angles = _close_polygon(angles)

    fig = plt.figure()
    ax = fig.add_subplot(111, polar=True)
    # First axis at 12 o'clock, categories laid out clockwise.
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)
    ax.set_xticks(angles, categories)

    all_values = [*target_values, *control_values]
    # nan-aware so an empty cluster group (NaN likelihood) does not corrupt the scale.
    max_value = np.nanmax(all_values)
    # Start the radial axis 5 % below the smallest value so the smallest vertex is
    # not drawn at the very centre.
    min_value = np.nanmin(all_values) * 0.95
    ticks = np.linspace(min_value, max_value, 4)
    ax.set_rlabel_position(0)
    ax.set_yticks(ticks, [str(np.round(tick, 4)) for tick in ticks], color="grey", size=7)
    ax.set_ylim(min_value, max_value)

    # Outline and fill share one colour per group so the legend matches the shading.
    for label, values, color in (
        (guide_target, target_values, "b"),
        ("Control", control_values, "r"),
    ):
        closed_values = _close_polygon(values)
        ax.plot(closed_angles, closed_values, linewidth=1, linestyle="solid", color=color, label=label)
        ax.fill(closed_angles, closed_values, color=color, alpha=0.1)

    ax.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1))
    ax.set_title(f"MELD likelihood {guide_target}")
    return fig


guide_targets = list(
    adata_concat.obs.loc[adata_concat.obs["gRNA_group"] != "control", "gRNA_group"].unique()
)
# The control profile does not depend on the guide target: compute it once.
control_profile = [get_likelyhood_median(category, "control") for category in SPIDER_CATEGORIES]

for guide_target in guide_targets:
    target_profile = [get_likelyhood_median(category, guide_target) for category in SPIDER_CATEGORIES]
    fig = _plot_spider_chart(guide_target, target_profile, control_profile, SPIDER_CATEGORIES)
    fig.savefig(sfile(f"target-{guide_target}-meld-likelihood-spiderchart.pdf"), transparent=True)
    plt.show()
    # Close explicitly so a long target list does not keep every figure in memory.
    plt.close(fig)

In [None]:
# Summary table written to CSV: one "Control" row followed by one row per guide
# target. Each row holds the aggregated MELD likelihood (see get_likelyhood_median)
# for each T-helper cluster group.
SUMMARY_CATEGORIES = ["Th17", "Th1", "Treg"]


def _likelihood_summary_row(group_label, target):
    """Return {"group": group_label, <category>: likelihood of `target`, ...}."""
    row = {"group": group_label}
    for category in SUMMARY_CATEGORIES:
        row[category] = get_likelyhood_median(category, target)
    return row


summary_targets = list(
    adata_concat.obs.loc[adata_concat.obs["gRNA_group"] != "control", "gRNA_group"].unique()
)
# The control row is loop-invariant, so it is computed once and placed first.
summary_records = [_likelihood_summary_row("Control", "control")]
summary_records += [_likelihood_summary_row(target, target) for target in summary_targets]

# Build the frame in one go instead of concatenating inside the loop. This also
# gives a proper 0..n-1 index rather than every row carrying index 0, and works
# even when there are no guide targets.
df = pd.DataFrame(summary_records, columns=["group", *SUMMARY_CATEGORIES])

# NOTE: output file name (including "likelyhook") kept unchanged for existing consumers.
df.to_csv(sfile("median-meld-likelyhook.csv"))