In [None]:
# Project description.
# PATIENT_NAME selects the sub-folder of the raw data directory that is read
# below; both values are combined into the name of the saved checkpoint file.
PROJECT_NAME = "ANCA"
PATIENT_NAME = "P139"

import os
import math

# Import section
import logging
# Configure the root logger to emit messages of level INFO and above.
logging.basicConfig(level=logging.INFO)

import warnings
# Install "ignore" filters for these three warning categories; they apply to
# everything executed afterwards in this kernel.
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=DeprecationWarning)

import os
from glob import glob

from anndata import AnnData

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.vq import kmeans, vq

import scanpy as sc
import scirpy as ir

# Print scanpy's header so the notebook output records the environment it ran in.
sc.logging.print_header()

import scvelo as scv

# Same for scvelo, followed by its global settings for this session.
scv.logging.print_version()
scv.settings.verbosity = 3
scv.settings.presenter_view = True
scv.set_figure_params('scvelo')


# All paths are resolved relative to the working directory of the kernel.
BASE_DIR = os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, "data")

# Sub-folders of the data directory: raw inputs and saved .h5ad checkpoints.
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
PROJECT_CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")


# Checkpoint handling functions

def save_checkpoint(adata_obj, filename, overwrite=False):
    """Write ``adata_obj`` as an ``.h5ad`` file into the checkpoint directory.

    Args:
        adata_obj: Object providing a ``write_h5ad(path)`` method.
        filename: File name, joined onto ``PROJECT_CHECKPOINT_DIR``.
        overwrite: When False (default) an existing file is never replaced.

    Raises:
        FileExistsError: If the target file exists and ``overwrite`` is False.
    """
    filename = os.path.join(PROJECT_CHECKPOINT_DIR, filename)
    if os.path.isfile(filename) and not overwrite:
        raise FileExistsError(f"File '{filename}' already exists")

    # The checkpoint folder may not exist yet (e.g. on a fresh checkout);
    # create it so the write below does not fail on a missing directory.
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    adata_obj.write_h5ad(filename)

def load_checkpoint(filename):
    """Read a checkpoint from the checkpoint directory.

    Args:
        filename: File name, joined onto ``PROJECT_CHECKPOINT_DIR``.

    Returns:
        Whatever ``sc.read_h5ad`` returns for the resolved path.

    Raises:
        FileNotFoundError: If the resolved path is not an existing file.
    """
    checkpoint_path = os.path.join(PROJECT_CHECKPOINT_DIR, filename)
    if os.path.isfile(checkpoint_path):
        return sc.read_h5ad(checkpoint_path)
    raise FileNotFoundError(f"Cant find file '{checkpoint_path}'")

def list_checkpoints():
    """Return the names of all entries inside the checkpoint directory.

    The number of entries found is printed as a side effect.

    Returns:
        list[str]: Entry names without their directory part, in the order
        reported by ``glob``.
    """
    pattern = os.path.join(PROJECT_CHECKPOINT_DIR, "*")
    checkpoint_names = []
    for path in glob(pattern):
        checkpoint_names.append(os.path.basename(path))
    print(f"Found {len(checkpoint_names)} checkpoint files in dir '{PROJECT_CHECKPOINT_DIR}'")
    return checkpoint_names


# Cluster Hashtag data
def get_hashtag_splitting_threshold(adata_obj, obs_name: str, seed=None):
    """Estimate a threshold separating the low and the high group of an obs column.

    The values of ``adata_obj.obs[obs_name]`` are split into two groups with
    k-means. The smallest value of each group is determined and the larger of
    those two minima is returned, i.e. the lowest value of the upper group.

    Args:
        adata_obj: Object exposing an ``obs`` table indexable by column name.
        obs_name: Name of the numeric ``obs`` column to split.
        seed: Optional seed forwarded to ``scipy.cluster.vq.kmeans`` to make
            the result reproducible. ``None`` (default) keeps the random
            initialisation of the original implementation.

    Returns:
        float: The splitting threshold.

    Raises:
        ValueError: If fewer than two values are available or if the values
            cannot be separated into two groups (e.g. all values identical).
    """
    values = np.asarray(adata_obj.obs[obs_name], dtype=float)
    if values.size < 2:
        raise ValueError(
            f"Need at least two values in obs column '{obs_name}' to compute a splitting threshold"
        )

    # Forward the seed only when one was given, so the call keeps working with
    # SciPy releases whose kmeans() does not accept the keyword.
    kmeans_kwargs = {} if seed is None else {"seed": seed}
    codebook, _ = kmeans(values, 2, **kmeans_kwargs)
    labels, _ = vq(values, codebook)

    # kmeans() can return fewer centroids than requested, so only the groups
    # that actually received members are evaluated.
    cluster_minima = [values[labels == label].min() for label in np.unique(labels)]
    if len(cluster_minima) < 2:
        raise ValueError(
            f"Values of obs column '{obs_name}' could not be separated into two clusters"
        )

    return float(max(cluster_minima))


def plot_hashtag_expression(adata_obj, obs_name: str, splitting_th: int = None, title: str = None, ax=None):
    """Draw a 100-bin histogram of an obs column, optionally marking a threshold.

    Args:
        adata_obj: Object whose ``obs`` table holds the column to plot.
        obs_name: Column of ``obs`` shown on the x axis; also used as the
            title when ``title`` is None.
        splitting_th: If given, a dark red vertical line is drawn at this x
            position, from y=1 up to the current upper y limit.
        title: Optional first line of the plot title.
        ax: Optional axes handed on to ``sns.histplot``.

    Returns:
        The axes object returned by ``sns.histplot``.
    """
    heading = obs_name if title is None else title

    ax = sns.histplot(adata_obj.obs, x=obs_name, bins=100, ax=ax)
    ax.set_title(f"{heading}\nNormalized Expression Filter Threshold")
    ax.set_xlabel("log1p(expression)")
    ax.set_ylabel("frequency")

    if splitting_th is None:
        return ax

    upper_y = ax.get_ylim()[1]
    ax.vlines(x=splitting_th, ymin=1, ymax=upper_y, color="darkred")
    return ax


def _report_cells_kept(kept_cells: int, total_cells: int) -> None:
    """Print how many cells survived a filter; safe for an empty input object."""
    # Guard the percentage so filtering an already empty object does not
    # raise ZeroDivisionError.
    percent = round(kept_cells / total_cells * 100, 2) if total_cells else 0.0
    print(f"{kept_cells} / {total_cells} cells kept ({percent} %).")


def keep_cells_above_threshold(adata_obj: AnnData, obs_name: str, splitting_th: float) -> AnnData:
    """Keep the cells whose ``obs[obs_name]`` value is >= ``splitting_th``.

    Args:
        adata_obj: Annotated data object to filter.
        obs_name: Numeric column of ``adata_obj.obs`` to compare.
        splitting_th: Inclusive lower bound.

    Returns:
        The result of indexing ``adata_obj`` with the labels of the kept
        cells. A summary line is printed as a side effect.
    """
    total_cells = len(adata_obj.obs)
    kept_labels = adata_obj.obs[adata_obj.obs[obs_name] >= splitting_th].index
    adata_obj = adata_obj[kept_labels, :]
    _report_cells_kept(len(adata_obj.obs), total_cells)
    return adata_obj


def keep_cells_below_threshold(adata_obj: AnnData, obs_name: str, splitting_th: float) -> AnnData:
    """Keep the cells whose ``obs[obs_name]`` value is <= ``splitting_th``.

    Args:
        adata_obj: Annotated data object to filter.
        obs_name: Numeric column of ``adata_obj.obs`` to compare.
        splitting_th: Inclusive upper bound.

    Returns:
        The result of indexing ``adata_obj`` with the labels of the kept
        cells. A summary line is printed as a side effect.
    """
    total_cells = len(adata_obj.obs)
    kept_labels = adata_obj.obs[adata_obj.obs[obs_name] <= splitting_th].index
    adata_obj = adata_obj[kept_labels, :]
    _report_cells_kept(len(adata_obj.obs), total_cells)
    return adata_obj


def exclude_cluster(adata_obj: AnnData, cluster: str, obs_name: str = "leiden") -> AnnData:
    """Drop every cell whose ``obs[obs_name]`` equals ``cluster``.

    Args:
        adata_obj: Annotated data object to filter.
        cluster: Label of the cluster to remove.
        obs_name: Column of ``adata_obj.obs`` holding the cluster labels.

    Returns:
        The result of indexing ``adata_obj`` with the boolean keep-mask.
        A summary line is printed as a side effect.
    """
    total_cells = len(adata_obj.obs)
    adata_obj = adata_obj[adata_obj.obs[obs_name] != cluster, :]
    _report_cells_kept(len(adata_obj.obs), total_cells)
    return adata_obj

In [None]:
# Load the 10x single-cell matrix of this patient with scanpy.
# gex_only=False is passed so that non-gene-expression features are read too;
# the next cell splits the object by its "feature_types" column.
adata = sc.read_10x_mtx(
    os.path.join(RAW_DATA_DIR, PATIENT_NAME),
    var_names='gene_symbols',
    cache=False,
    gex_only=False,
)
adata.var_names_make_unique()

# Keep an untouched copy of the loaded matrix in its own layer.
adata.layers["counts"] = adata.X.copy()

print(f"{len(adata)} cells in dataset")

In [None]:
# Split the features into three objects: surface proteins, the two sample
# hashtags ("HashB"/"HashK") and gene expression.
is_antibody_feature = adata.var["feature_types"] == "Antibody Capture"
is_hashtag_feature = adata.var["gene_ids"].isin(["HashB", "HashK"])

protein = adata[:, is_antibody_feature & (~is_hashtag_feature)].copy()

hashtags = adata[:, is_antibody_feature & is_hashtag_feature].copy()

rna = adata[:, adata.var["feature_types"] == "Gene Expression"].copy()

### Hashtag preprocessing

In [None]:
# Maps the obs column to create -> the hashtag id looked up in var["gene_ids"].
ORGANS = {
    "hashtag_kidney": "HashK",
    "hashtag_blood": "HashB",
}

# Preprocessing: store log1p-transformed hashtag values in a separate layer.
hashtags.layers["log1p"] = np.log1p(hashtags.X).copy()

# Map expression to OBS data
for hashtag_obs, hashtag_id in ORGANS.items():
    # Use the same exact-match mask for the presence check and for the column
    # selection, so the two can never disagree (a substring/regex match could
    # report "found" while the equality selection picks no column).
    is_hashtag = (hashtags.var["gene_ids"] == hashtag_id).to_numpy()
    if is_hashtag.any():
        print(f"{hashtag_id} found in hashtags.")
        # Flatten the (n_cells, 1) column so obs receives a plain 1-D vector.
        hashtags.obs[hashtag_obs] = hashtags[:, is_hashtag].layers["log1p"].toarray().ravel()
    else:
        print(f"{hashtag_id} not in hashtags.")

In [None]:
# Attach the hashtag obs columns to the RNA object by joining on the obs index.
# No `how` argument is passed, so the default join type of DataFrame.merge applies.
# Earlier alternative, kept for reference: rna.obs = hashtags.obs.copy()
rna.obs = rna.obs.merge(hashtags.obs, left_index=True, right_index=True)

### CITE-seq preprocessing

In [None]:
# Store the log1p-transformed values of every antibody feature as an obs
# column named after the feature.
for adt_id in protein.var.index.tolist():
    adt_column = protein[:, protein.var.index == adt_id]
    protein.obs[adt_id] = np.log1p(adt_column.X.toarray())

In [None]:
# Merge the per-antibody obs columns created above into the RNA object,
# joining on the obs index (default join type of DataFrame.merge).
rna.obs = rna.obs.merge(protein.obs, left_index=True, right_index=True)

### Transcriptome processing

In [None]:
# Flag genes by name prefix: 'MT-' -> var['mt'], 'RPL'/'RPS' -> var['ribo'].
rna.var['mt'] = rna.var_names.str.startswith('MT-')
rna.var['ribo'] = rna.var_names.str.startswith('RPL') | rna.var_names.str.startswith('RPS')
# Compute QC metrics in place for both flags. The following cells read
# 'total_counts', 'n_genes_by_counts' and 'pct_counts_mt' from rna.obs.
sc.pp.calculate_qc_metrics(rna, qc_vars=['mt', "ribo"], percent_top=None, log1p=False, inplace=True)

In [None]:
# Two stacked QC scatter plots against total counts:
# detected genes per cell (top) and mitochondrial percentage (bottom).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize = (4, 8))
p1 = sc.pl.scatter(rna, x='total_counts', y='n_genes_by_counts', show=False, ax=ax1)
p2 = sc.pl.scatter(rna, x='total_counts', y='pct_counts_mt', show=False, ax=ax2)

In [None]:
# Remove cells whose mitochondrial count percentage reaches the cutoff (%)
MITO_CUTOFF = 20

total_cell_count = len(rna)
# .copy() turns the filtered selection into an independent object, so the
# in-place preprocessing in the following cells does not operate on a view.
rna = rna[rna.obs.pct_counts_mt < MITO_CUTOFF, :].copy()

# The first message reports the removed cells (count and share), the second
# one the remaining cells.
removed_cell_count = total_cell_count - len(rna)
print(f"Filter by cutoff {MITO_CUTOFF}% out "
      f"{removed_cell_count}/{total_cell_count} cells by parameter "
      f"'pct_counts_mt' ({round(removed_cell_count / total_cell_count * 100, 2)}%)")


print(f"Got a final count of {len(rna)} cells in "
      f"dataset ({round(len(rna) / total_cell_count * 100, 2)}%)")

In [None]:
# Per-cell normalisation followed by scvelo's dispersion-based gene filter
# with the mean/dispersion bounds given below.
# subset=False is passed — presumably so that no genes are removed from rna
# at this point; confirm against the scvelo documentation.
scv.pp.normalize_per_cell(rna)
scv.pp.filter_genes_dispersion(
    rna,
    min_mean=0.0125,
    max_mean=3,
    min_disp=0.5,
    subset=False
)

In [None]:
# Keep the current (normalised, not yet scaled) state in .raw before scaling.
rna.raw = rna

In [None]:
# Scale rna in place with scanpy, passing max_value=10.
# NOTE(review): no log transform of rna.X is visible in the preceding cells —
# confirm that scaling the non-logged values is intended.
sc.pp.scale(rna, max_value=10)

In [None]:
# First embedding/clustering pass on the RNA object with default parameters:
# PCA -> neighbour graph -> UMAP -> Leiden (results are stored on rna).
sc.tl.pca(rna, svd_solver='arpack')
sc.pp.neighbors(rna)
sc.tl.umap(rna)
sc.tl.leiden(rna)

In [None]:
# UMAP plot coloured by Leiden cluster.
# Note: the figure is only created and laid out here; nothing in this cell
# writes it to disk.
plot = sc.pl.umap(rna,
                  color=["leiden"],
                  show = False,
                  frameon = False,
                  title="UMAP with leiden clustering")

fig = plot.get_figure()
fig.set_tight_layout(True)

In [None]:
# Path of the contig annotation CSV of this patient; it is passed to
# ir.io.read_10x_vdj in the next cell.
filename = os.path.join(RAW_DATA_DIR, PATIENT_NAME, "filtered_contig_annotations.csv")

In [None]:
# Read the V(D)J contig annotations with scirpy.
tcr = ir.io.read_10x_vdj(path=filename)

# Add the TCR obs columns to the RNA object. The left join keeps every RNA
# cell, including cells without a matching TCR entry.
rna.obs = rna.obs.merge(tcr.obs, how="left", left_index=True, right_index=True)

In [None]:
# Run scirpy's chain QC on the RNA object, which now carries the merged TCR
# obs columns.
ir.tl.chain_qc(rna)

In [None]:
# Compute receptor distances, then define clonotypes with the settings below.
# A later cell reads the result from obs["clone_id"].
ir.pp.ir_dist(rna)
ir.tl.define_clonotypes(rna, receptor_arms="all", dual_ir="primary_only")

In [None]:
# Run scirpy's clonal expansion tool on the RNA object (default parameters).
ir.tl.clonal_expansion(rna)

In [None]:
def make_unique_clone_id(adata_obj: AnnData, prefix):
    adata_obj.obs.loc[adata_obj.obs["clone_id"].isna(), "clone_id"] = None
    adata_obj.obs["clone_id"] = adata_obj.obs["clone_id"].astype(str)
    adata_obj.obs.loc[
        ~adata_obj.obs["clone_id"].isna(),
        "clone_id"
    ] = prefix + "-" + adata_obj.obs.loc[
        adata_obj.obs["clone_id"] != "nan",
        "clone_id"
    ]
    return adata_obj

In [None]:
# make_unique_clone_id returns the object it was given, so from here on
# `adata` and `rna` refer to the same processed RNA object; the full matrix
# loaded at the top is no longer reachable under the name `adata`.
adata = make_unique_clone_id(rna, PATIENT_NAME)

### Merge Single Cell Velocity data

In [None]:
scvelo_adata = scv.read_loom(os.path.join(RAW_DATA_DIR, PATIENT_NAME, "PatientAlignment.loom"))
scvelo_adata.var_names_make_unique()

# Rename indices: the labels are expected to look like "<prefix>:<barcode>x"
# and are converted to "<barcode>-1" so they can be matched with adata.obs.
unique_index_prefix = scvelo_adata.obs.index[0].split(":")[0]


def _loom_label_to_barcode(label: str) -> str:
    """Strip the leading '<prefix>:' and turn a trailing 'x' into '-1'."""
    prefix_tag = f"{unique_index_prefix}:"
    # Remove the prefix first and only touch the final character; replacing
    # every "x" in the label would corrupt labels whose prefix contains "x".
    if label.startswith(prefix_tag):
        label = label[len(prefix_tag):]
    if label.endswith("x"):
        label = label[:-1] + "-1"
    return label


scvelo_adata.obs = scvelo_adata.obs.rename(index=_loom_label_to_barcode)

# Rename columns to avoid conflicts
scvelo_adata.obs = scvelo_adata.obs.rename(columns = {
    "_X": "scvelo_tsne_X",
    "_Y": "scvelo_tsne_Y",
    "Clusters": "scvelo_clusters",
})

In [None]:
# Barcodes present in both the scvelo object and the RNA object.
intersect_barcodes = scvelo_adata.obs.index.intersection(adata.obs.index)

# Restrict the RNA dataset to the shared barcodes. Both objects are indexed
# with the same `intersect_barcodes` (scvelo_adata in a later cell).
adata = adata[intersect_barcodes, :]

In [None]:
# Merge the scvelo obs columns into the RNA object. The left join on the obs
# index keeps every cell of adata.
adata.obs = adata.obs.merge(scvelo_adata.obs, how="left", left_index=True, right_index=True)
print(f"{len(adata)} final cells in dataset")

In [None]:
# Merge the scvelo var (per-gene) columns into the RNA object. The left join
# on the var index keeps every gene of adata.
adata.var = adata.var.merge(scvelo_adata.var, how="left", left_index=True, right_index=True)

In [None]:
# Restrict the scvelo object to the same barcodes as the filtered RNA data
scvelo_adata = scvelo_adata[intersect_barcodes, :]

# Check if same size
assert len(adata) == len(scvelo_adata)

# Equal length alone does not prove that the rows line up. The layers are
# copied positionally in the next cell, so the barcode order must match too.
assert (adata.obs_names == scvelo_adata.obs_names).all(), \
    "Barcode order of adata and scvelo_adata differs"

In [None]:
# Copy every layer of the scvelo object into the filtered RNA object.
# NOTE(review): the layers are copied as-is, so this assumes that both
# objects hold the same genes in the same order — confirm.
for layer_name in list(scvelo_adata.layers):
    print(f"Merge layer {layer_name}")
    adata.layers[layer_name] = scvelo_adata.layers[layer_name].copy()

### Isolate CD3+ cells

In [None]:
# Re-embed and re-cluster the merged object: PCA with 40 components,
# neighbour graph with 40 neighbours on 40 PCs, Leiden at resolution 0.5,
# then UMAP.
sc.pp.pca(adata, svd_solver='arpack', n_comps=40)
sc.pp.neighbors(adata, n_neighbors=40, n_pcs=40)
sc.tl.leiden(adata, resolution = 0.5)
sc.tl.umap(adata)

In [None]:
# UMAP coloured by "CD4", "CD4_TotalSeqC" and the Leiden clusters —
# presumably the basis for the cluster exclusion in the next cell.
sc.pl.umap(adata, color=["CD4", "CD4_TotalSeqC", "leiden"])

In [None]:
# Remove the Leiden clusters "2" and "4", one after the other. The labels
# refer to the clustering computed above and need re-checking whenever that
# clustering changes.
for excluded_cluster in ("2", "4"):
    adata = exclude_cluster(adata_obj=adata, cluster=excluded_cluster)

In [None]:
# A fixed threshold of 0.1 on "CD8_TotalSeqC" is used for both the plot and
# the filter; the k-means based estimate is kept commented out for reference:
# th = get_hashtag_splitting_threshold(adata, "CD8_TotalSeqC")
plot_hashtag_expression(adata, "CD8_TotalSeqC", splitting_th=0.1)
# Keep only cells with a CD8_TotalSeqC value <= 0.1.
adata = keep_cells_below_threshold(adata, "CD8_TotalSeqC", splitting_th=0.1)

### Isolate CD4+ cells

In [None]:
# Recompute PCA, neighbour graph, Leiden clustering and UMAP on the remaining
# cells, this time with default parameters.
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.leiden(adata)
sc.tl.umap(adata)

In [None]:
# UMAP of the remaining cells coloured by "CD4_TotalSeqC" and the new Leiden clusters.
sc.pl.umap(adata, color=["CD4_TotalSeqC", "leiden"])

### Save checkpoint

In [None]:
# save_checkpoint() itself joins the name onto PROJECT_CHECKPOINT_DIR, so only
# the bare file name is passed. Joining the directory here as well only worked
# because os.path.join drops earlier components when given an absolute path.
save_checkpoint(
    adata_obj=adata,
    filename=f"{PROJECT_NAME}-{PATIENT_NAME}-preprocessed.h5ad",
    overwrite=True
)