In [None]:
# Identifiers of the analysed project and sample. SAMPLE selects the raw-data
# sub-folder that is loaded below; both values are used to name the exported
# figures and the checkpoint file.
PROJECT_NAME = "CropSeq-23-1"
SAMPLE = "Th17-2"


import json
import os
import re

import pandas as pd
import numpy as np
from glob import glob

import seaborn as sns
import matplotlib.pyplot as plt

import scanpy as sc
from scipy import stats
import scvelo as scv
import scirpy as ir

# Project layout, anchored at the working directory of the kernel.
BASE_DIR = os.getcwd()
NOTEBOOK_DIR = os.path.join(BASE_DIR, "notebooks")
DATA_DIR = os.path.join(BASE_DIR, "data")

# Input data
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")

# Intermediate AnnData checkpoints
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")

# Exported results; figures are written to the "pdf" sub-folder
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
PDF_DIR = os.path.join(PROCESSED_DIR, "pdf")


def sfile(filename):
    """Return the output path for a figure of the current project and sample.

    The file name is prefixed with ``PROJECT_NAME`` and ``SAMPLE`` and placed
    inside ``PDF_DIR``. The directory is created on demand so that a following
    ``savefig`` call does not fail when the output folder does not exist yet.

    Parameters
    ----------
    filename : str
        Base name of the figure file, including its extension.

    Returns
    -------
    str
        Full path of the figure file. The path is also printed.
    """
    os.makedirs(PDF_DIR, exist_ok=True)
    _fname = os.path.join(PDF_DIR, f"{PROJECT_NAME}_{SAMPLE}_{filename}")
    print(f"File save at '{_fname}'")
    return _fname

# Checkpoint handling functions

def save_checkpoint(adata_obj, filename, overwrite=False):
    """Write ``adata_obj`` as an ``.h5ad`` file into the checkpoint directory.

    Parameters
    ----------
    adata_obj
        Object providing ``write_h5ad`` (an AnnData object in this notebook).
    filename : str
        File name, resolved against ``CHECKPOINT_DIR``.
    overwrite : bool, default False
        Allow replacing an already existing checkpoint file.

    Raises
    ------
    FileExistsError
        If the target file exists and ``overwrite`` is false.
    """
    filename = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.isfile(filename) and not overwrite:
        raise FileExistsError(f"File '{filename}' already exists")

    # The checkpoint folder may not exist yet on a fresh checkout.
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    adata_obj.write_h5ad(filename)

def load_checkpoint(filename):
    """Read a checkpoint file from ``CHECKPOINT_DIR`` with ``sc.read_h5ad``.

    Raises ``FileNotFoundError`` when no file with that name exists in the
    checkpoint directory.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.isfile(checkpoint_path):
        return sc.read_h5ad(checkpoint_path)
    raise FileNotFoundError(f"Cant find file '{checkpoint_path}'")

def list_checkpoints():
    """Return the names of all entries found directly in ``CHECKPOINT_DIR``.

    Only the last path component of every match is returned; the number of
    matches is printed as a side effect.
    """
    checkpoint_names = [
        os.path.basename(path)
        for path in glob(os.path.join(CHECKPOINT_DIR, "*"))
    ]
    print(f"Found {len(checkpoint_names)} checkpoint files in dir '{CHECKPOINT_DIR}'")
    return checkpoint_names


# CRISPR library loader

def load_cells_per_protospacer(dirname: str) -> dict:
    """Invert the guide -> barcodes mapping of ``cells_per_protospacer.json``.

    Parameters
    ----------
    dirname : str
        Directory that contains ``cells_per_protospacer.json``. The file is
        expected to map each guide name to a list of cell barcodes.

    Returns
    -------
    dict
        Maps every cell barcode to the list of guides it was listed under, in
        file order and without duplicates.
    """
    # Load crispr screen data; the context manager closes the file handle.
    with open(os.path.join(dirname, "cells_per_protospacer.json"), "r") as handle:
        crispr_data = json.load(handle)

    cell_dict = {}
    for guide, barcodes in crispr_data.items():
        for barcode in barcodes:
            guides = cell_dict.setdefault(barcode, [])
            # A barcode may be listed twice for the same guide.
            if guide not in guides:
                guides.append(guide)

    return cell_dict


def load_protospacer_per_cell(dirname: str) -> pd.DataFrame:
    """Read ``protospacer_calls_per_cell.csv`` from ``dirname``.

    Returns the table exactly as parsed by ``pd.read_csv``, one row per line
    of the CSV file.
    """
    return pd.read_csv(os.path.join(dirname, "protospacer_calls_per_cell.csv"))

### Load raw data

In [None]:
# Read the 10x count matrix stored in the raw-data folder of the current sample.
adata = sc.read_10x_mtx(os.path.join(RAW_DATA_DIR, SAMPLE))

In [None]:
# Per-cell protospacer (guide) calls, read from the same sample folder.
crispr_data = load_protospacer_per_cell(os.path.join(RAW_DATA_DIR, SAMPLE))

In [None]:
# A guide of a multi-guide cell is only kept when it accounts for more than
# this fraction of the summed guide UMIs of that cell.
THRESHOLD_PREC = 0.05

# Corrected UMI counts
crispr_data_dict = crispr_data.to_dict()
corrected_umis = {}
corrected_feature_counts = {}
corrected_calls = {}

for row_key, umi_entry in crispr_data_dict["num_umis"].items():
    call_entry = crispr_data_dict["feature_call"][row_key]

    # Cells with a single guide are copied through unchanged.
    if crispr_data_dict["num_features"][row_key] == 1:
        corrected_umis[row_key] = umi_entry
        corrected_feature_counts[row_key] = 1
        corrected_calls[row_key] = call_entry
        continue

    # Multi-guide cells hold "|"-separated values, one entry per guide.
    guide_umis = [int(count) for count in umi_entry.split("|")]
    guide_names = call_entry.split("|")
    min_umis = np.sum(guide_umis) * THRESHOLD_PREC

    kept_positions = [pos for pos, count in enumerate(guide_umis) if count > min_umis]

    corrected_umis[row_key] = "|".join(str(guide_umis[pos]) for pos in kept_positions)
    corrected_feature_counts[row_key] = len(kept_positions)
    corrected_calls[row_key] = "|".join(guide_names[pos] for pos in kept_positions)

crispr_data_dict["corr_num_umis"] = corrected_umis
crispr_data_dict["corr_num_features"] = corrected_feature_counts
crispr_data_dict["corr_feature_call"] = corrected_calls

crispr_data = pd.DataFrame(crispr_data_dict)

In [None]:
# Number of cells per corrected guide count. value_counts orders the result by
# frequency, so the list position does not tell the guide count.
list(crispr_data["corr_num_features"].value_counts())

In [None]:
total_cells = len(adata)

# Cells per corrected guide number. value_counts orders its result by
# frequency, so the counts are looked up explicitly for 1..max guides. This
# keeps position i of multi_guide equal to "cells with i guides", even when a
# guide number is absent or a higher number is more frequent than a lower one.
guide_number_counts = crispr_data["corr_num_features"].value_counts()
max_guides = int(guide_number_counts.index.max()) if len(guide_number_counts) else 0
cells_per_guide_number = [
    int(guide_number_counts.get(n_guides, 0)) for n_guides in range(1, max_guides + 1)
]

# Position 0 collects every cell of adata without a retained guide call.
multi_guide = [total_cells - sum(cells_per_guide_number), *cells_per_guide_number]

assert sum(multi_guide) == total_cells

In [None]:
# Bar chart: number of cells (y) per number of guides found in a cell (x).
fig, ax = plt.subplots(figsize=(7, 7))

sns.barplot(y=multi_guide, x=list(range(len(multi_guide))), color="skyblue", ax=ax)

# Axes.set returns a list of setter results; its return value is not bound to
# `ax` so that `ax` keeps referring to the Axes object.
ax.set(xlabel="Number of unique guide gRNA per cell", ylabel="Cell count", title="gRNA distribution")

fig.savefig(sfile("guide-frequency-barchart.pdf"), transparent=True)

In [None]:
# Three shades of blue, one per pie slice.
colors = plt.get_cmap('Blues')(np.linspace(0.8, 0.3, 3))

# Pie chart: cells without guide, with one guide, and with more than one guide.
fig, ax = plt.subplots()
_ = ax.pie(
    [multi_guide[0], multi_guide[1], sum(multi_guide[2:])],
    labels=["none", "single", "multi"],
    colors=colors,
    autopct="%1.1f%%",
    startangle=90,
)

# Axes.set returns a list of setter results; its return value is not bound to
# `ax` so that `ax` keeps referring to the Axes object.
ax.set(title="gRNA distribution")

fig.savefig(sfile("guide-frequency-piechart.pdf"), transparent=True)

In [None]:
# Isolate guide singlets: keep the rows whose corrected guide count is exactly one
has_single_guide = crispr_data["corr_num_features"] == 1
crispr_data = crispr_data[has_single_guide]

In [None]:
# Remove guide reads from count data: drop every feature whose gene id matches
# the guide pattern (str.match anchors the pattern at the start of the id).
adata = adata[:,~adata.var["gene_ids"].str.match("gRNA_(.*)_gene")]

In [None]:
# Remove non-singlet cells from the dataset: select only the barcodes that are
# still listed in crispr_data.
adata = adata[crispr_data["cell_barcode"].values,:]

In [None]:
def split_guide_name(name: str) -> tuple:
    """Split a guide feature name into its target and version parts.

    The name must have the form ``gRNA_<target>_<version>_capture``, where the
    target consists of letters, digits and hyphens and the version of one or
    two digits.

    Returns
    -------
    tuple
        ``(target, version)``, both as strings.

    Raises
    ------
    ValueError
        If ``name`` does not follow the expected form.
    """
    parsed = re.match(r"^gRNA_(?P<target>[a-zA-Z0-9\-]*)_(?P<version>[0-9]{1,2})_capture$", name)

    if parsed is not None:
        return parsed.groups()

    raise ValueError(f"Failed to get guide from {name}.")

In [None]:
# Index the guide table by barcode so it can be joined to the cell annotation.
crispr_data.index = crispr_data["cell_barcode"]

guide_adata = adata.copy()

# Join the guide calls to the per-cell annotation (index against index).
obs_with_guides = adata.obs.merge(crispr_data, left_index=True, right_index=True)
obs_with_guides = obs_with_guides.drop(["cell_barcode", "corr_num_features"], axis=1)

# Expose the corrected call and UMI count under their final column names.
obs_with_guides["guide_name"] = obs_with_guides["corr_feature_call"]
obs_with_guides["guide_num_umis"] = obs_with_guides["corr_num_umis"]
obs_with_guides = obs_with_guides.drop(["corr_feature_call", "corr_num_umis"], axis=1)

# Convert type
obs_with_guides["guide_num_umis"] = obs_with_guides["guide_num_umis"].astype(dtype="int")

guide_adata.obs = obs_with_guides

In [None]:
# Split every guide name exactly once, then write each annotation column with
# a single assignment instead of updating obs cell by cell through .loc.
parsed_guides = [
    split_guide_name(guide_name) for guide_name in guide_adata.obs["guide_name"]
]

guide_adata.obs["guide_target"] = [target for target, _ in parsed_guides]
guide_adata.obs["guide_version"] = [version for _, version in parsed_guides]


In [None]:
# Flag genes by name prefix ('mt-' as mitochondrial, 'Rpl'/'Rps' as ribosomal)
# and compute the QC metrics for both groups; 'pct_counts_mt' is used for
# plotting and filtering below.
guide_adata.var['mt'] = guide_adata.var_names.str.startswith('mt-')
guide_adata.var['ribo'] = guide_adata.var_names.str.startswith('Rpl') | guide_adata.var_names.str.startswith('Rps')
sc.pp.calculate_qc_metrics(guide_adata, qc_vars=['mt', "ribo"], percent_top=None, log1p=False, inplace=True)

In [None]:
# QC overview, both panels against total counts per cell:
# top - number of detected genes, bottom - percentage of mitochondrial counts.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize = (4, 8))
p1 = sc.pl.scatter(guide_adata, x='total_counts', y='n_genes_by_counts', show=False, ax=ax1)
p2 = sc.pl.scatter(guide_adata, x='total_counts', y='pct_counts_mt', show=False, ax=ax2)

In [None]:
# Filter cells by their percentage of mitochondrial counts (cutoff in %)
MITO_CUTOFF = 5

total_cell_count = len(guide_adata)
# Copy the selection so that the in-place preprocessing that follows works on
# an independent object instead of a view.
guide_adata = guide_adata[guide_adata.obs.pct_counts_mt < MITO_CUTOFF, :].copy()

removed_cell_count = total_cell_count - len(guide_adata)

# The percentage in the first message refers to the removed cells, matching
# the removed/total count printed next to it.
print(f"Filter by cutoff {MITO_CUTOFF}% out "
      f"{removed_cell_count}/{total_cell_count} cells by parameter "
      f"'pct_counts_mt' ({round(removed_cell_count / total_cell_count * 100, 2)}%)")


print(f"Got a final count of {len(guide_adata)} cells in "
      f"dataset ({round(len(guide_adata) / total_cell_count * 100, 2)}%)")

In [None]:
# Normalise the counts of every cell, then evaluate gene dispersion with the
# given mean/dispersion bounds. With subset=False the genes are presumably
# only flagged, not removed - verify against the scvelo documentation.
# NOTE(review): the visible cells apply no log transformation before scaling
# and PCA - confirm that this is intended.
scv.pp.normalize_per_cell(guide_adata)
scv.pp.filter_genes_dispersion(
    guide_adata,
    min_mean=0.0125,
    max_mean=3,
    min_disp=0.5,
    subset=False
)

In [None]:
# Keep the current state of the data in .raw before it is changed by
# regressing out total counts and scaling (max_value=10).
guide_adata.raw = guide_adata
sc.pp.regress_out(guide_adata, ['total_counts'])
sc.pp.scale(guide_adata, max_value=10)

In [None]:
# PCA -> neighbourhood graph -> Leiden clustering (resolution 0.5) -> UMAP.
# All other parameters are left at the library defaults.
sc.pp.pca(guide_adata)
sc.pp.neighbors(guide_adata)
sc.tl.leiden(guide_adata, resolution=0.5)
sc.tl.umap(guide_adata)

In [None]:
# Show the Leiden clusters on the UMAP embedding
plot = sc.pl.umap(
    guide_adata,
    color=["leiden"],
    show = False,
    frameon = False,
    title=["UMAP with leiden clustering"]
)

### Load TCR data

In [None]:
# TCR contig annotation file in the raw-data folder of the current sample.
filename = os.path.join(RAW_DATA_DIR, SAMPLE, "filtered_contig_annotations.csv")

In [None]:
# Load TCR
tcr = ir.io.read_10x_vdj(path=filename)

# Insert TCR data into full adata. The left join keeps every cell of
# guide_adata; cells without an entry in tcr.obs get missing values.
# NOTE(review): this relies on the receptor annotation being stored in
# tcr.obs - confirm that this holds for the installed scirpy version.
guide_adata.obs = pd.DataFrame.merge(guide_adata.obs, tcr.obs, left_index=True, right_index=True, how="left")

# QC: run the chain QC, then plot the abundance of "receptor_subtype" groups
# split by Leiden cluster.
ir.tl.chain_qc(guide_adata)
ax = ir.pl.group_abundance(guide_adata, groupby="receptor_subtype", target_col="leiden")

In [None]:
# Abundance of "chain_pairing" groups, split by Leiden cluster.
ax = ir.pl.group_abundance(guide_adata, groupby="chain_pairing", target_col="leiden")

In [None]:
# UMAP coloured by chain pairing, restricted to the "single pair" group.
ax = sc.pl.umap(guide_adata, color="chain_pairing", groups="single pair")

In [None]:
# Compute receptor distances and define clonotypes from them; the resulting
# "clone_id" column is post-processed further below.
ir.pp.ir_dist(guide_adata)
ir.tl.define_clonotypes(guide_adata, receptor_arms="all", dual_ir="primary_only")

# Annotate clonal expansion and show it on the UMAP embedding.
ir.tl.clonal_expansion(guide_adata)
sc.pl.umap(guide_adata, color="clonal_expansion")

In [None]:
def make_unique_clone_id(adata_obj, prefix):
    """Prefix every clone id in ``adata_obj.obs["clone_id"]`` with ``prefix``.

    Existing ids become ``"<prefix>-<id>"``. Cells without a clone id stay
    missing (NaN). A missing id is ``None``, ``NaN`` or the literal string
    ``"nan"`` that a previous string conversion may have left behind.

    Parameters
    ----------
    adata_obj
        Object with a DataFrame attribute ``obs`` that has a ``clone_id``
        column (an AnnData object in this notebook). Modified in place.
    prefix : str
        Text placed in front of every clone id, e.g. the sample name.

    Returns
    -------
    The same ``adata_obj`` that was passed in.
    """
    # Work on plain Python objects so categorical or string-typed columns
    # accept the new, prefixed values.
    clone_ids = adata_obj.obs["clone_id"].astype(object)

    has_clone = clone_ids.notna() & (clone_ids != "nan")

    # Start from a column in which every missing id is NaN, then fill in the
    # prefixed ids for the cells that have one.
    unique_ids = clone_ids.where(has_clone)
    unique_ids.loc[has_clone] = prefix + "-" + clone_ids.loc[has_clone].astype(str)

    adata_obj.obs["clone_id"] = unique_ids
    return adata_obj

In [None]:
# Prefix the clone ids with the sample name. make_unique_clone_id changes
# guide_adata.obs in place and returns the object it was given, so
# filtered_rna is a second name for guide_adata, not a copy.
filtered_rna = make_unique_clone_id(guide_adata, SAMPLE)

### Save checkpoint

In [None]:
# save_checkpoint resolves the file name against CHECKPOINT_DIR itself, so
# only the bare file name is passed here.
save_checkpoint(
    adata_obj=guide_adata,
    filename=f"{PROJECT_NAME}-{SAMPLE}-preprocessed.h5ad",
    overwrite=True
)