## Finalize short and long harmonized AnnData after Seurat SCT for M132TS

This notebook takes the Seurat SCT output and the previously harmonized long-reads and short-reads AnnData objects, performs a final round of filtering (e.g. doublet removal), clustering, and annotation, and writes four analysis-ready AnnData objects (short/long × raw/SCT).

**Inputs and Outputs**
- Inputs:
  - harmonized long-reads and short-reads AnnData objects (raw counts, all genes)
  - harmonized long-reads and short-reads AnnData objects (SCT counts, all genes) [from Seurat script]
- Outputs:
  - Four AnnData objects, each including the same cluster annotations and embeddings
    - short raw
    - short SCT
    - long raw
    - long SCT
  - Figures

In [None]:
%matplotlib inline

import matplotlib.pylab as plt

import numpy as np
import pandas as pd
import os
import sys
from time import time
import logging
import pickle
from operator import itemgetter

import scanpy as sc

# Font sizes (in points) shared by the matplotlib figures in this notebook.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

# (rc group, parameter, value) triples, applied in the listed order.
_rc_font_settings = (
    ('font', 'size', SMALL_SIZE),          # default text size
    ('axes', 'titlesize', SMALL_SIZE),     # axes title
    ('axes', 'labelsize', MEDIUM_SIZE),    # x and y axis labels
    ('xtick', 'labelsize', SMALL_SIZE),    # x tick labels
    ('ytick', 'labelsize', SMALL_SIZE),    # y tick labels
    ('legend', 'fontsize', SMALL_SIZE),    # legend text
    ('figure', 'titlesize', BIGGER_SIZE),  # figure title
)
for _rc_group, _rc_param, _rc_value in _rc_font_settings:
    plt.rc(_rc_group, **{_rc_param: _rc_value})

# Root logger at INFO level. `log_info` is bound to the WARNING-level method,
# so messages sent through it are emitted with WARNING severity.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_info = logger.warning

# Suppress all Python warnings from here on.
import warnings
warnings.filterwarnings("ignore")

sc.settings.set_figure_params(dpi=80, facecolor='white')

In [None]:
# Root of the analysis checkout; every input and output lives beneath it.
repo_root = '/home/jupyter/mb-ml-data-disk/MAS-seq-analysis'

# inputs
input_prefix = 'M132TS_immune.revised_v2.harmonized'
output_path = 'output/t-cell-vdj-cite-seq'

# Directory shared by all input and output files below.
_data_dir = os.path.join(repo_root, output_path)

# harmonized AnnData objects (raw counts)
harmonized_short_adata_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.short.stringtie.h5ad')
harmonized_long_adata_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.long.stringtie.h5ad')

# harmonized AnnData objects after the Seurat SCT script
harmonized_short_adata_seurat_output_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.short.stringtie.seurat_output.no_mt_pct_regression.h5ad')
harmonized_long_adata_seurat_output_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.long.stringtie.seurat_output.no_mt_pct_regression.h5ad')

# outputs
final_short_adata_raw_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.short.stringtie.final.raw.h5ad')
final_short_adata_sct_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.short.stringtie.final.sct.h5ad')
final_long_adata_raw_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.long.stringtie.final.raw.h5ad')
final_long_adata_sct_h5_path = os.path.join(
    _data_dir, f'{input_prefix}.long.stringtie.final.sct.h5ad')

In [None]:
# Hyperparameters for short-read feature selection, PCA, the kNN graph and UMAP.
# rank (in descending variance order) whose value is used as the feature cutoff
n_sct_features = 3000
# number of principal components computed and used for the kNN graph
n_pcs = 30
# neighborhood size passed to sc.pp.neighbors
n_neighbors = 100
# neighborhood size passed to UMAP
n_neighbors_umap = 100
umap_min_dist = 0.1
umap_spread = 20.0
# distance metric shared by the kNN graph and UMAP
metric = 'euclidean'

# neglect highly expressed genes for clustering?
# Fraction of genes, ranked by descending mean raw count, to exclude;
# with 0.0 the cutoff is the maximum mean count, so no gene is excluded.
neglect_high_expression_fraction = 0.0

## Short adata final clustering (using SCT features)

In [None]:
# Load the raw-count short-read object and the matching Seurat SCT output.
# The *_h5_path variables are already absolute, so os.path.join returns them unchanged.
adata_short_raw = sc.read(os.path.join(repo_root, harmonized_short_adata_h5_path))
adata_short_seurat = sc.read(os.path.join(repo_root, harmonized_short_adata_seurat_output_h5_path))

# Re-key the Seurat feature table by its '_index' column and drop 'features'.
adata_short_seurat.var.set_index('_index', drop=True, inplace=True)
adata_short_seurat.var.drop(columns=['features'], inplace=True)

# Gene ids come from the raw object.
adata_short_seurat.var['gene_ids'] = adata_short_raw.var['gene_ids']

# Per-cell QC columns, copied from the raw object under Seurat-style names.
for _seurat_col, _raw_col in (
        ('percent.mt', 'pct_counts_mt'),
        ('nCount_RNA', 'total_counts')):
    adata_short_seurat.obs[_seurat_col] = adata_short_raw.obs[_raw_col]

In [None]:
# highly variable?
# 'rv' is the per-feature variance of the SCT matrix across cells. Features whose
# variance is strictly greater than the value at rank `n_sct_features` (descending
# order) are flagged in 'hv'.
# NOTE(review): np.var over .X assumes a dense matrix — confirm.
adata_short_seurat.var['rv'] = np.var(adata_short_seurat.X, axis=0)
rv_cutoff = np.sort(adata_short_seurat.var['rv'])[::-1][n_sct_features]
adata_short_seurat.var['hv'] = adata_short_seurat.var['rv'] > rv_cutoff

# expression in range?
# The cutoff is the mean raw count at rank `neglect_high_expression_fraction * n_genes`
# (descending order); genes at or below it are flagged in 'eir'.
expr_cutoff = np.sort(adata_short_raw.var['mean_counts'].values)[::-1][
    int(neglect_high_expression_fraction * len(adata_short_raw.var))]
expr_in_range = adata_short_raw.var['mean_counts'] <= expr_cutoff
adata_short_seurat.var['eir'] = expr_in_range[adata_short_seurat.var.index]

# subset to highly variable and expression-in-range features
adata_short_seurat = adata_short_seurat[:, adata_short_seurat.var['hv'] & adata_short_seurat.var['eir']]

# scale, then PCA and the kNN graph on the selected features
sc.pp.scale(adata_short_seurat)
sc.tl.pca(adata_short_seurat, svd_solver='arpack', n_comps=n_pcs)
sc.pp.neighbors(adata_short_seurat, n_neighbors=n_neighbors, n_pcs=n_pcs, metric=metric)

In [None]:
from umap import UMAP

# Preliminary 2-D embedding of the PCA coordinates.
_prelim_umap = UMAP(
    densmap=False,
    min_dist=umap_min_dist,
    spread=umap_spread,
    n_neighbors=n_neighbors_umap,
    metric=metric)
adata_short_seurat.obsm['X_umap'] = _prelim_umap.fit_transform(adata_short_seurat.obsm['X_pca'])

# Bring the CD45-family TotalSeq-C columns over from the raw object for plotting.
for _obs_col in (
        'CD45_TotalSeqC',
        'CD45R_B220_TotalSeqC',
        'CD45RA_TotalSeqC',
        'CD45RO_TotalSeqC'):
    adata_short_seurat.obs[_obs_col] = adata_short_raw.obs[_obs_col]

In [None]:
# Preliminary UMAP colored by the CD45RA and CD45RO TotalSeq-C columns (color scale capped at 8).
sc.pl.umap(adata_short_seurat, color=['CD45RA_TotalSeqC', 'CD45RO_TotalSeqC'], s=20, vmax=8)

In [None]:
# Preliminary UMAP colored by the per-cell QC columns copied from the raw object:
# 'percent.mt' (from 'pct_counts_mt') and 'nCount_RNA' (from 'total_counts').
sc.pl.umap(adata_short_seurat, color=['percent.mt'], s=20, vmax=20)
sc.pl.umap(adata_short_seurat, color=['nCount_RNA'], s=20, vmax=10000)

## Preliminary clustering, doublet scrubbing, and proof-reading

In [None]:
# Preliminary Leiden clustering (resolution 1.3); labels are stored in
# obs['mehrtash_leiden'] and shown on the UMAP.
sc.tl.leiden(adata_short_seurat, resolution=1.3, key_added='mehrtash_leiden')
sc.pl.umap(adata_short_seurat, color=['mehrtash_leiden'], s=20)

In [None]:
import scrublet as scr

# Score doublets on the raw short-read count matrix.
scrub = scr.Scrublet(adata_short_raw.X, expected_doublet_rate=0.10)
doublet_scores, predicted_doublets = scrub.scrub_doublets()

# NOTE(review): the arrays are attached positionally, which assumes
# adata_short_seurat and adata_short_raw list the same cells in the same
# order — confirm.
adata_short_seurat.obs['doublet_scores'] = doublet_scores

# Scrublet's own call plus calls at fixed score thresholds, stored as 0/1 flags.
# The builtin `int` is used because the `np.int` alias was removed in NumPy 1.24.
adata_short_seurat.obs['predicted_doublets'] = predicted_doublets.astype(int)
adata_short_seurat.obs['predicted_doublets_0.20'] = (doublet_scores > 0.20).astype(int)
adata_short_seurat.obs['predicted_doublets_0.25'] = (doublet_scores > 0.25).astype(int)
adata_short_seurat.obs['predicted_doublets_0.30'] = (doublet_scores > 0.30).astype(int)
adata_short_seurat.obs['predicted_doublets_0.35'] = (doublet_scores > 0.35).astype(int)
adata_short_seurat.obs['predicted_doublets_0.40'] = (doublet_scores > 0.40).astype(int)

In [None]:
# obs columns holding the 0/1 doublet flags: Scrublet's own call followed by the
# fixed score thresholds 0.20-0.40.
doublet_keys = [
    'predicted_doublets',
    'predicted_doublets_0.20',
    'predicted_doublets_0.25',
    'predicted_doublets_0.30',
    'predicted_doublets_0.35',
    'predicted_doublets_0.40'
]

# One UMAP panel per doublet flag.
sc.pl.umap(adata_short_seurat, color=doublet_keys, s=20)

In [None]:
# Report, for each flag, the fraction of cells marked as doublets
# (sum of the 0/1 column divided by the number of cells).
for key in doublet_keys:
    log_info(f'{key}, doublet fraction: {np.sum(adata_short_seurat.obs[key]) / len(adata_short_seurat):.3f}')

In [None]:
# Reload the Seurat SCT output with all features, so marker genes are ranked over
# every feature rather than only the subset selected for clustering.
adata_short_seurat_full = sc.read(os.path.join(repo_root, harmonized_short_adata_seurat_output_h5_path))
adata_short_seurat_full.var.set_index('_index', drop=True, inplace=True)
adata_short_seurat_full.var.drop(columns=['features'], inplace=True)
# NOTE(review): adata_short_seurat has already been subset to the selected features,
# so pandas index alignment leaves 'gene_ids' missing for all other features — confirm
# this is acceptable.
adata_short_seurat_full.var['gene_ids'] = adata_short_seurat.var['gene_ids']
adata_short_seurat_full.obs['percent.mt'] = adata_short_seurat.obs['percent.mt']
adata_short_seurat_full.obs['nCount_RNA'] = adata_short_seurat.obs['nCount_RNA']
adata_short_seurat_full.obs['mehrtash_leiden'] = adata_short_seurat.obs['mehrtash_leiden']
# Drop .raw — presumably so the ranking below is computed on .X; verify.
adata_short_seurat_full.raw = None
adata_test = adata_short_seurat_full

# Per-cluster marker genes (t-test) for the preliminary Leiden clusters.
sc.tl.rank_genes_groups(adata_test, 'mehrtash_leiden', method='t-test')
sc.pl.rank_genes_groups(adata_test, n_genes=20, sharey=False)

In [None]:
# Preliminary Leiden clusters to discard entirely.
unwanted_leiden_ids = {
    '8', # MT- high
}

# Keep cells that are outside the unwanted clusters AND not flagged as doublets at
# the 0.25 score threshold. The builtin `bool` is used because the `np.bool` alias
# was removed in NumPy 1.24.
keep_cluster = ~adata_short_seurat.obs['mehrtash_leiden'].isin(unwanted_leiden_ids)
keep_singlet = ~adata_short_seurat.obs['predicted_doublets_0.25'].astype(bool)

# Barcodes (obs index) of the retained cells, in their original order.
final_barcodes = adata_short_seurat.obs.index[(keep_cluster & keep_singlet).to_numpy()]

## Final clustering

In [None]:
# Reload both short-read objects restricted to the retained barcodes and repeat the
# annotation steps of the preliminary pass.
adata_short_raw = sc.read(os.path.join(repo_root, harmonized_short_adata_h5_path))[final_barcodes]
adata_short_seurat = sc.read(os.path.join(repo_root, harmonized_short_adata_seurat_output_h5_path))[final_barcodes]
adata_short_seurat.var.set_index('_index', drop=True, inplace=True)
adata_short_seurat.var.drop(columns=['features'], inplace=True)
adata_short_seurat.var['gene_ids'] = adata_short_raw.var['gene_ids']
adata_short_seurat.obs['percent.mt'] = adata_short_raw.obs['pct_counts_mt']
adata_short_seurat.obs['nCount_RNA'] = adata_short_raw.obs['total_counts']

# Feature selection by variance, recomputed on the filtered cells. Unlike the
# preliminary pass, no expression-in-range ('eir') filter is applied here.
adata_short_seurat.var['rv'] = np.var(adata_short_seurat.X, axis=0)
rv_cutoff = np.sort(adata_short_seurat.var['rv'])[::-1][n_sct_features]
adata_short_seurat.var['hv'] = adata_short_seurat.var['rv'] > rv_cutoff
adata_short_seurat = adata_short_seurat[:, adata_short_seurat.var['hv']]

# Scale, PCA and kNN graph on the selected features.
sc.pp.scale(adata_short_seurat)
sc.tl.pca(adata_short_seurat, svd_solver='arpack', n_comps=n_pcs)
sc.pp.neighbors(adata_short_seurat, n_neighbors=n_neighbors, n_pcs=n_pcs, metric=metric)

# Final short-read UMAP; random_state is fixed here, unlike the preliminary embedding.
adata_short_seurat.obsm['X_umap'] = UMAP(
    random_state=1,
    densmap=False,
    min_dist=umap_min_dist,
    spread=umap_spread,
    n_neighbors=n_neighbors_umap,
    metric=metric).fit_transform(adata_short_seurat.obsm['X_pca'])

# CD45-family TotalSeq-C columns from the raw object, for plotting.
adata_short_seurat.obs['CD45_TotalSeqC'] = adata_short_raw.obs['CD45_TotalSeqC']
adata_short_seurat.obs['CD45R_B220_TotalSeqC'] = adata_short_raw.obs['CD45R_B220_TotalSeqC']
adata_short_seurat.obs['CD45RA_TotalSeqC'] = adata_short_raw.obs['CD45RA_TotalSeqC']
adata_short_seurat.obs['CD45RO_TotalSeqC'] = adata_short_raw.obs['CD45RO_TotalSeqC']

sc.pl.umap(adata_short_seurat, color=['CD45RA_TotalSeqC', 'CD45RO_TotalSeqC'], s=20, vmax=8)

In [None]:
# Final Leiden clustering (resolution 1.15) on the filtered cells; this overwrites
# obs['mehrtash_leiden'] with the new numeric labels.
sc.tl.leiden(adata_short_seurat, resolution=1.15, key_added='mehrtash_leiden')
sc.pl.umap(adata_short_seurat, color=['mehrtash_leiden'], s=20)

In [None]:
# map leiden labels to cell type names
mehrtash_leiden_names_map = {
    '0': 'SMC',
    '1': 'A/EE',
    '2': 'EA',
    '3': 'TD II',
    '4': 'TD I',
    '5': 'CE',
    '6': 'P',
    '7': 'EA II'
}

leiden_name_to_color_map = {
    'SMC': '#00acc6',
    'EA': '#018700',
    'A/EE': '#8c3bff',
    'TD II': '#6b004f',
    'TD I': '#eb0077',
    'CE': '#ff7ed1',
    'P': '#ffa52f',
    'EA II': '#708297',
}

In [None]:
# change leiden labels to cell type names
# The map is keyed by the numeric Leiden ids, so running this cell a second time
# (when the column already holds names) would turn every label into None.
_cluster_ids = adata_short_seurat.obs['mehrtash_leiden'].values
_cluster_names = [mehrtash_leiden_names_map.get(cluster_id) for cluster_id in _cluster_ids]
adata_short_seurat.obs['mehrtash_leiden'] = pd.Series(
    data=_cluster_names,
    index=adata_short_seurat.obs.index,
    dtype='category')

# set colors
# One color per category, listed in the categorical's own category order.
_name_categories = adata_short_seurat.obs['mehrtash_leiden'].values.categories.values
adata_short_seurat.uns['mehrtash_leiden_colors'] = [
    leiden_name_to_color_map.get(category_name) for category_name in _name_categories]

In [None]:
# import colorcet as cc

# adata_short_seurat.uns['mehrtash_leiden_colors'] = leiden_color_map

# def rgb_to_hex(x) -> str: 
#     r = max(0, min(int(255 * x[0]), 255))
#     g = max(0, min(int(255 * x[1]), 255))
#     b = max(0, min(int(255 * x[2]), 255))
#     return "#{0:02x}{1:02x}{2:02x}".format(r, g, b)
# leiden_categories = adata_short_seurat.obs['mehrtash_leiden'].values.categories
# n_leiden = len(leiden_categories)
# 
# leiden_color_map = {
#     color: category
#     for category, color in zip(
#         adata_short_seurat.obs['mehrtash_leiden'].values.categories,
#         leiden_color_list)}

# Final short-read UMAP colored by the 'mehrtash_leiden' column.
sc.pl.umap(adata_short_seurat, color=['mehrtash_leiden'], s=20)

In [None]:
# Reload the all-feature Seurat SCT output for the retained barcodes and rank marker
# genes for the final clusters.
adata_short_seurat_full = sc.read(os.path.join(repo_root, harmonized_short_adata_seurat_output_h5_path))[final_barcodes]
# NOTE(review): this assignment is repeated further down; it may be here to turn the
# barcode-subset view into a standalone object before the in-place var edits — confirm
# before removing either copy.
adata_short_seurat_full.obs['mehrtash_leiden'] = adata_short_seurat.obs['mehrtash_leiden']
adata_short_seurat_full.var.set_index('_index', drop=True, inplace=True)
adata_short_seurat_full.var.drop(columns=['features'], inplace=True)
# NOTE(review): adata_short_seurat only holds the selected features, so index
# alignment leaves 'gene_ids' missing for the others — confirm this is acceptable.
adata_short_seurat_full.var['gene_ids'] = adata_short_seurat.var['gene_ids']
adata_short_seurat_full.obs['percent.mt'] = adata_short_seurat.obs['percent.mt']
adata_short_seurat_full.obs['nCount_RNA'] = adata_short_seurat.obs['nCount_RNA']
adata_short_seurat_full.obs['mehrtash_leiden'] = adata_short_seurat.obs['mehrtash_leiden']
# Drop .raw — presumably so the ranking below is computed on .X; verify.
adata_short_seurat_full.raw = None
adata_test = adata_short_seurat_full

# Per-cluster marker genes (t-test) for the final clusters.
sc.tl.rank_genes_groups(adata_test, 'mehrtash_leiden', method='t-test')
sc.pl.rank_genes_groups(adata_test, n_genes=20, sharey=False)

In [None]:
# leiden_names = []
# for leiden_id in adata_short_seurat.obs['mehrtash_leiden'].values:
#     leiden_names.append(mehrtash_leiden_names_map[leiden_id])
# new_color_map = {
#     color: mehrtash_leiden_names_map[old_category]
#     for color, old_category in adata_short_seurat.uns['mehrtash_leiden_colors'].items()}
# adata_short_seurat.obs['mehrtash_leiden'] = pd.Categorical(leiden_names)
# adata_short_seurat.uns['mehrtash_leiden_colors'] = new_color_map

## Long adata final clustering (using SCT features)

In [None]:
# Hyperparameters for the long-read pass. These rebind the same global names used for
# the short-read pass; only n_sct_features differs (5000 instead of 3000).
n_sct_features = 5000
n_pcs = 30
n_neighbors = 100
n_neighbors_umap = 100
umap_min_dist = 0.1
umap_spread = 20.0
metric = 'euclidean'

In [None]:
# Load the long-read raw and Seurat SCT objects, restricted to the barcodes retained
# from the short-read filtering.
adata_long_raw = sc.read(os.path.join(repo_root, harmonized_long_adata_h5_path))[final_barcodes]
adata_long_seurat = sc.read(os.path.join(repo_root, harmonized_long_adata_seurat_output_h5_path))[final_barcodes]

In [None]:
# Re-key the Seurat feature table by its '_index' column and drop 'features'.
adata_long_seurat.var.set_index('_index', drop=True, inplace=True)
adata_long_seurat.var.drop(columns=['features'], inplace=True)

# Copy the feature annotations from the raw long-read object under the same column
# names. (The original wrote 'gene_eq_classes' to a misspelled 'gene_eq_classest'
# column; the name now matches the source column.)
for _var_col in (
        'transcript_eq_classes',
        'gene_eq_classes',
        'transcript_ids',
        'gene_ids',
        'gene_names',
        'is_de_novo',
        'is_gene_id_ambiguous',
        'is_tcr_overlapping'):
    adata_long_seurat.var[_var_col] = adata_long_raw.var[_var_col]

In [None]:
# Feature selection by per-feature variance of the SCT matrix: keep features whose
# variance is strictly above the value at rank `n_sct_features` (descending order).
# NOTE(review): np.var over .X assumes a dense matrix — confirm.
adata_long_seurat.var['rv'] = np.var(adata_long_seurat.X, axis=0)
rv_cutoff = np.sort(adata_long_seurat.var['rv'])[::-1][n_sct_features]
adata_long_seurat.var['hv'] = adata_long_seurat.var['rv'] > rv_cutoff
adata_long_seurat = adata_long_seurat[:, adata_long_seurat.var['hv']]

# Scale, PCA and kNN graph on the selected features.
sc.pp.scale(adata_long_seurat)
sc.tl.pca(adata_long_seurat, svd_solver='arpack', n_comps=n_pcs)
sc.pp.neighbors(adata_long_seurat, n_neighbors=n_neighbors, n_pcs=n_pcs, metric=metric)

In [None]:
from umap import UMAP

# Long-read UMAP, initialized from the final short-read UMAP coordinates.
# NOTE(review): the init array assumes both objects hold the same cells in the same
# order (both were subset with final_barcodes) — confirm.
# NOTE(review): unlike the final short-read UMAP, no random_state is passed here, so
# the layout may not be reproducible between runs — confirm whether that is intended.
adata_long_seurat.obsm['X_umap'] = UMAP(
    densmap=False,
    min_dist=umap_min_dist,
    spread=umap_spread,
    n_neighbors=n_neighbors_umap,
    init=adata_short_seurat.obsm['X_umap'],
    metric=metric).fit_transform(adata_long_seurat.obsm['X_pca'])

In [None]:
# Long-read UMAP without any coloring.
sc.pl.umap(adata_long_seurat)

In [None]:
# CD45-family TotalSeq-C columns, copied from the raw long-read object.
for _obs_col in (
        'CD45_TotalSeqC',
        'CD45R_B220_TotalSeqC',
        'CD45RA_TotalSeqC',
        'CD45RO_TotalSeqC'):
    adata_long_seurat.obs[_obs_col] = adata_long_raw.obs[_obs_col]

# Cluster labels and their colors are reused from the short-read clustering.
adata_long_seurat.uns['mehrtash_leiden_colors'] = adata_short_seurat.uns['mehrtash_leiden_colors']
adata_long_seurat.obs['mehrtash_leiden'] = adata_short_seurat.obs['mehrtash_leiden']

In [None]:
# Long-read UMAP colored by the four CD45-family columns and the cluster labels.
sc.pl.umap(adata_long_seurat, color=['CD45_TotalSeqC', 'CD45R_B220_TotalSeqC', 'CD45RA_TotalSeqC', 'CD45RO_TotalSeqC', 'mehrtash_leiden'])

## Save

In [None]:
# Stash the cluster labels, colors and SCT-based embeddings in plain variables so they
# can be attached to the freshly reloaded objects in the following cells.
mehrtash_leiden = adata_short_seurat.obs['mehrtash_leiden']
mehrtash_leiden_colors = adata_short_seurat.uns['mehrtash_leiden_colors']

# PCA / UMAP computed on the long-read SCT features
adata_long_X_pca_SCT = adata_long_seurat.obsm['X_pca']
adata_long_X_umap_SCT = adata_long_seurat.obsm['X_umap']

# PCA / UMAP computed on the short-read SCT features
adata_short_X_pca_SCT = adata_short_seurat.obsm['X_pca']
adata_short_X_umap_SCT = adata_short_seurat.obsm['X_umap']

In [None]:
# Build the final short-read raw object: reload it for the retained barcodes, keep its
# pre-existing PCA/t-SNE, wipe uns/obsm, then attach the shared embeddings and labels.
adata_short_raw = sc.read(os.path.join(repo_root, harmonized_short_adata_h5_path))[final_barcodes]

# Embeddings already stored in the harmonized raw file; saved before obsm is cleared.
adata_short_X_pca_raw = adata_short_raw.obsm['X_pca']
adata_short_X_tsne_raw = adata_short_raw.obsm['X_tsne']

adata_short_raw.uns.clear()
adata_short_raw.obsm.clear()

# SCT-based embeddings from both modalities
adata_short_raw.obsm['X_pca_SCT_short'] = adata_short_X_pca_SCT
adata_short_raw.obsm['X_pca_SCT_long'] = adata_long_X_pca_SCT

adata_short_raw.obsm['X_umap_SCT_short'] = adata_short_X_umap_SCT
adata_short_raw.obsm['X_umap_SCT_long'] = adata_long_X_umap_SCT

# original embeddings of the raw short-read object
adata_short_raw.obsm['X_tsne_raw_short'] = adata_short_X_tsne_raw
adata_short_raw.obsm['X_pca_raw_short'] = adata_short_X_pca_raw

# shared cluster annotation
adata_short_raw.obs['mehrtash_leiden'] = mehrtash_leiden
adata_short_raw.uns['mehrtash_leiden_colors'] = mehrtash_leiden_colors

In [None]:
# Build the final short-read SCT object: reload the Seurat output for the retained
# barcodes and replace its metadata with that of the finalized raw object.
adata_short_sct = sc.read(os.path.join(repo_root, harmonized_short_adata_seurat_output_h5_path))[final_barcodes]
adata_short_sct.var.set_index('_index', drop=True, inplace=True)

# Discard everything carried over from the Seurat export except X and var.
adata_short_sct.raw = None
adata_short_sct.uns.clear()
adata_short_sct.varm.clear()
adata_short_sct.obsp.clear()
# obs, obsm and uns come from the raw object so both outputs share the same
# annotations and embeddings.
adata_short_sct.obs = adata_short_raw.obs
adata_short_sct.obsm = adata_short_raw.obsm
adata_short_sct.uns = adata_short_raw.uns
# var is the raw object's feature table, restricted to the SCT features.
adata_short_sct.var = adata_short_raw[:, adata_short_sct.var.index].var

In [None]:
# Build the final long-read raw object: reload it for the retained barcodes, wipe
# uns/obsm, then attach the shared embeddings and labels.
adata_long_raw = sc.read(os.path.join(repo_root, harmonized_long_adata_h5_path))[final_barcodes]

adata_long_raw.uns.clear()
adata_long_raw.obsm.clear()

# SCT-based embeddings from both modalities
adata_long_raw.obsm['X_pca_SCT_short'] = adata_short_X_pca_SCT
adata_long_raw.obsm['X_pca_SCT_long'] = adata_long_X_pca_SCT

adata_long_raw.obsm['X_umap_SCT_short'] = adata_short_X_umap_SCT
adata_long_raw.obsm['X_umap_SCT_long'] = adata_long_X_umap_SCT

# These two come from the short-read raw object (variables set in the short raw
# cell), so that cell must have been run first.
adata_long_raw.obsm['X_tsne_raw_short'] = adata_short_X_tsne_raw
adata_long_raw.obsm['X_pca_raw_short'] = adata_short_X_pca_raw

# shared cluster annotation
adata_long_raw.obs['mehrtash_leiden'] = mehrtash_leiden
adata_long_raw.uns['mehrtash_leiden_colors'] = mehrtash_leiden_colors

In [None]:
# Build the final long-read SCT object: reload the Seurat output for the retained
# barcodes and replace its metadata with that of the finalized raw object.
adata_long_sct = sc.read(os.path.join(repo_root, harmonized_long_adata_seurat_output_h5_path))[final_barcodes]
adata_long_sct.var.set_index('_index', drop=True, inplace=True)

# Discard everything carried over from the Seurat export except X and var.
adata_long_sct.raw = None
adata_long_sct.uns.clear()
adata_long_sct.varm.clear()
adata_long_sct.obsp.clear()
# obs, obsm and uns come from the raw object so both outputs share the same
# annotations and embeddings.
adata_long_sct.obs = adata_long_raw.obs
adata_long_sct.obsm = adata_long_raw.obsm
adata_long_sct.uns = adata_long_raw.uns
# var is the raw object's feature table, restricted to the SCT features.
adata_long_sct.var = adata_long_raw[:, adata_long_sct.var.index].var

In [None]:
# Write the four final objects, in the order short raw, short SCT, long raw, long SCT.
for _adata_out, _h5ad_path in (
        (adata_short_raw, final_short_adata_raw_h5_path),
        (adata_short_sct, final_short_adata_sct_h5_path),
        (adata_long_raw, final_long_adata_raw_h5_path),
        (adata_long_sct, final_long_adata_sct_h5_path)):
    _adata_out.write(_h5ad_path)

## Make plots

In [None]:
# Reload the four final objects from disk, so the plots below are made from the saved
# files rather than from in-memory state.
adata_short_raw = sc.read_h5ad(final_short_adata_raw_h5_path)
adata_short_sct = sc.read_h5ad(final_short_adata_sct_h5_path)
adata_long_raw = sc.read_h5ad(final_long_adata_raw_h5_path)
adata_long_sct = sc.read_h5ad(final_long_adata_sct_h5_path)

In [None]:
# Sum of all entries of the long-read raw matrix divided by the number of cells.
adata_long_raw.X.sum() / len(adata_long_raw)

In [None]:
def plot_embedding(
        adata: 'sc.AnnData',
        embedding_key: str,
        leiden_key: str,
        markersize=2,
        alpha=0.75,
        x_offset=None,
        y_offset=None,
        fig=None,
        ax=None,
        show_labels=True,
        figsize=(3, 3)):
    """Scatter-plot a 2-D embedding with one color per cluster category.

    Args:
        adata: object exposing ``obs`` (with a categorical ``leiden_key`` column),
            ``obsm`` (with ``embedding_key``) and ``uns`` (with
            ``f'{leiden_key}_colors'``, one color per category in category order).
        embedding_key: key in ``adata.obsm``; columns 0 and 1 are plotted.
        leiden_key: name of the categorical cluster column in ``adata.obs``.
        markersize: marker size passed to ``ax.scatter``.
        alpha: marker opacity passed to ``ax.scatter``.
        x_offset: optional ``{category: dx}`` shift applied to a category's label.
        y_offset: optional ``{category: dy}`` shift applied to a category's label.
            The two offsets are looked up independently; a missing entry means 0.
        fig: figure to draw on when ``ax`` is not given.
        ax: axes to draw on. When given it is always used, with or without ``fig``.
        show_labels: whether to write each category name at the mean position of
            its cells (plus the offsets).
        figsize: size of the figure created when neither ``fig`` nor ``ax`` is given.

    Raises:
        KeyError: if ``adata.uns`` has no ``f'{leiden_key}_colors'`` entry.
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            ax = fig.gca()

    # Fresh dicts per call instead of shared mutable defaults.
    x_offset = {} if x_offset is None else x_offset
    y_offset = {} if y_offset is None else y_offset

    leiden_color_key = f'{leiden_key}_colors'
    if leiden_color_key not in adata.uns:
        raise KeyError(
            f'{leiden_color_key!r} not found in adata.uns; '
            f'cluster colors are required to plot {leiden_key!r}')

    labels = adata.obs[leiden_key]
    categories = labels.values.categories
    coords = adata.obsm[embedding_key]

    # Colors are stored in category order, so pairing them positionally is valid.
    color_by_category = dict(zip(categories, adata.uns[leiden_color_key]))
    cell_color_list = [color_by_category.get(label) for label in labels]

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        color=cell_color_list,
        s=markersize,
        alpha=alpha)

    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_xlabel('UMAP1')
    ax.set_ylabel('UMAP2')

    if not show_labels:
        return

    for leiden_category in categories:
        in_category = (labels == leiden_category).to_numpy()
        if not in_category.any():
            # A category without cells has no position to label.
            continue
        x_c = np.mean(coords[in_category, 0]) + x_offset.get(leiden_category, 0)
        y_c = np.mean(coords[in_category, 1]) + y_offset.get(leiden_category, 0)
        ax.text(
            x_c, y_c, leiden_category,
            fontsize=8,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

In [None]:
import matplotlib as mpl
# Reset every matplotlib rc setting to the library defaults before making the
# publication figures; this overrides the rc settings applied earlier in the notebook.
mpl.rcParams.update(mpl.rcParamsDefault)

In [None]:
# Short-read UMAP figure, colored by cluster and saved as a PDF.
adata = adata_short_raw
embedding_key = 'X_umap_SCT_short'
leiden_key = 'mehrtash_leiden'
# Label offsets only take effect when show_labels=True; they are unused in this call.
# NOTE(review): 'A IV' is not one of the cluster names assigned earlier — confirm
# whether these offsets are stale.
x_offset = {'A IV': -22}
y_offset = {'A IV': -2}

fig, ax = plt.subplots(figsize=(3, 3))

plot_embedding(
    adata, embedding_key, leiden_key,
    markersize=2,
    alpha=0.5,
    x_offset=x_offset,
    y_offset=y_offset,
    fig=fig,
    show_labels=False,
    ax=ax)

fig.tight_layout()

# Relative path: the file lands in ./output under the current working directory.
plt.savefig('./output/M132TS__UMAP__short.pdf')

In [None]:
# Long-read UMAP figure, colored by cluster and saved as a PDF.
adata = adata_long_raw
embedding_key = 'X_umap_SCT_long'
leiden_key = 'mehrtash_leiden'
# Label offsets only take effect when show_labels=True; they are unused in this call.
x_offset = {'A IV': -7}
y_offset = {'A IV': -16}

fig, ax = plt.subplots(figsize=(3, 3))

plot_embedding(
    adata, embedding_key, leiden_key,
    markersize=2,
    alpha=0.5,
    x_offset=x_offset,
    y_offset=y_offset,
    show_labels=False,
    fig=fig,
    ax=ax)

# Same layout adjustment as the short-read figure, so the two PDFs have matching
# margins and the axis labels are not clipped.
fig.tight_layout()

# Relative path: the file lands in ./output under the current working directory.
plt.savefig('./output/M132TS__UMAP__long.pdf')

## Gene-level concordance between short and long

In [None]:
# Column names used by the concordance analysis below.
# var column holding gene ids in the short-read object
ADATA_SHORT_GENE_IDS_COL = 'gene_ids'
# var column holding gene ids in the long-read object
ADATA_LONG_GENE_IDS_COL = 'gene_ids'
# obs column holding the cluster labels
LEIDEN_OBS_COL = 'mehrtash_leiden'

In [None]:
# Boolean cell masks keyed by group name: one mask selecting every cell, plus one
# mask per cluster category of the long-read object.
_leiden_values = adata_long_raw.obs[LEIDEN_OBS_COL].values
_leiden_categories = _leiden_values.categories

barcode_groups_dict = dict()
# The builtin `bool` is used because the `np.bool` alias was removed in NumPy 1.24.
barcode_groups_dict['All T Cells'] = np.ones((len(adata_long_raw),), dtype=bool)
for leiden_label in _leiden_categories:
    barcode_groups_dict[leiden_label] = _leiden_values == leiden_label

# Plot title for each group.
ax_title_dict = {
    'All T Cells': 'All T Cells',
    'A/EE': 'Activated / Early Exhausted',
    'CE': 'Cytotoxic Effector',
    'EA': 'Early Activated',
    'EA II': 'Early Activated II',
    'P': 'Proliferating',
    'SMC': 'Stem-like Memory',
    'TD I': 'Terminally Differentiated I',
    'TD II': 'Terminally Differentiated II'
}

# File-name suffix for each group's output figure.
output_suffix_dict = {
    'All T Cells': 'all',
    'A/EE': 'a_ee',
    'CE': 'ce',
    'EA': 'ea',
    'EA II': 'ea2',
    'P': 'p',
    'SMC': 'smc',
    'TD I': 'td1',
    'TD II': 'td2'
}

# Scatter color for each group: the stored cluster colors (listed in category
# order), plus black for the all-cells group.
color_map = dict(zip(_leiden_categories, adata_long_raw.uns[f'{LEIDEN_OBS_COL}_colors']))
color_map['All T Cells'] = 'black'

In [None]:
from itertools import groupby
from operator import itemgetter
import matplotlib.ticker as tck
from sklearn.metrics import r2_score


def drop_version(entry):
    """Strip the version suffix ('.N') from ids that start with 'ENS'."""
    return entry.split('.')[0] if entry.startswith('ENS') else entry


# --- group-independent bookkeeping, computed once rather than once per group ---

short_gene_ids = list(map(drop_version, adata_short_raw.var[ADATA_SHORT_GENE_IDS_COL].values))
long_gene_ids = list(map(drop_version, adata_long_raw.var[ADATA_LONG_GENE_IDS_COL].values))

# Genes present in both modalities; sorted so the order is the same on every run.
mutual_gene_ids_set = set(long_gene_ids).intersection(short_gene_ids)
mutual_gene_ids_list = sorted(mutual_gene_ids_set)

# gene id -> column indices of its entries in the long-read matrix
gene_id_to_tx_indices_map = dict()
for g in groupby(sorted(enumerate(long_gene_ids), key=itemgetter(1)), key=itemgetter(1)):
    gene_id = g[0]
    gene_id_to_tx_indices_map[gene_id] = list(map(itemgetter(0), g[1]))

# gene id -> column index in the short-read matrix (the last occurrence wins if an
# id appears more than once)
short_gene_ids_to_idx_map = {
    gene_id: idx for idx, gene_id in enumerate(short_gene_ids)}
mutual_indices_in_short = list(map(short_gene_ids_to_idx_map.get, mutual_gene_ids_list))

# --- one concordance figure per barcode group ---

# NOTE(review): the masks were built from adata_long_raw and are also applied to
# adata_short_raw, which assumes both objects list the same cells in the same
# order — confirm.
for barcode_group_key, barcode_mask in barcode_groups_dict.items():

    # Per-feature totals over the cells of this group.
    total_tx_expr_long = np.asarray(adata_long_raw[barcode_mask].X.sum(0)).flatten()
    total_gene_expr_short = np.asarray(adata_short_raw[barcode_mask].X.sum(0)).flatten()

    # Collapse the long-read totals to gene level, aligned with mutual_gene_ids_list.
    total_gene_expr_long = np.asarray([
        np.sum(total_tx_expr_long[gene_id_to_tx_indices_map[gene_id]])
        for gene_id in mutual_gene_ids_list])
    total_gene_expr_short = total_gene_expr_short[mutual_indices_in_short]

    # Normalize each modality to a total of one million over the shared genes.
    total_gene_expr_short_tpm = 1_000_000 * total_gene_expr_short / np.sum(total_gene_expr_short)
    total_gene_expr_long_tpm = 1_000_000 * total_gene_expr_long / np.sum(total_gene_expr_long)

    fig, ax = plt.subplots(figsize=(3.5, 3.5))

    # identity line
    ax.plot([1e-1, 1e5], [1e-1, 1e5], '--', lw=1, color='black')

    ax.scatter(total_gene_expr_short_tpm, total_gene_expr_long_tpm, s=1, alpha=0.2, color=color_map[barcode_group_key], rasterized=True)
    # R^2 on log1p values; the short-read values are passed as the first argument.
    r2 = r2_score(np.log1p(total_gene_expr_short_tpm), np.log1p(total_gene_expr_long_tpm))
    ax.text(0.15, 3e4, f'$R^2$ = {r2:.2f}', fontsize=10)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xticks([1e-1, 1e1, 1e3, 1e5])
    ax.set_yticks([1e-1, 1e1, 1e3, 1e5])
    ax.xaxis.set_minor_locator(tck.AutoMinorLocator())
    ax.yaxis.set_minor_locator(tck.AutoMinorLocator())
    ax.set_xlim((1e-1, 1e5))
    ax.set_ylim((1e-1, 1e5))
    ax.set_aspect('equal')
    # ax.set_title('M132TS')
    ax.set_xlabel('Short-reads GEX (TPM)')
    ax.set_ylabel('MAS-ISO-seq GEX (TPM)')
    ax.set_title(ax_title_dict[barcode_group_key])

    fig.tight_layout()

    plt.savefig(f'./output/M132TS__short_long_gex_concordance__{output_suffix_dict[barcode_group_key]}.pdf')