## Generate PTPRC CITE-seq vs. MAS-seq validation plots

**Inputs and Outputs**
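- Inputs:
  - long-read AnnData object for M132TS, incl. CITE-seq data (raw)
  - transcript equivalence class lookup table (TSV)
  - short-read SCT AnnData object for M132TS (used for the PTPRC gene expression plot)
  - manual (decision-tree-based) annotations of PTPRC (currently not loaded; the corresponding code is commented out)
- Outputs:
  - Figures (PDF)
  - long-read AnnData object extended with aggregated PTPRC isoform counts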

In [None]:
import os
import sys

import matplotlib.pylab as plt
import colorcet as cc
import numpy as np
import pandas as pd
from time import time
import logging
import pickle
import gffutils
import pysam
import umap
import scanpy as sc
import pickle

# Font sizes passed to the matplotlib rc settings below.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
# plt.style.use('dark_background')

# Root logger set to INFO. `log_info` is a plain alias for `print`, so
# messages emitted through it do not go through the logger configured here.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_info = print

# scanpy figure defaults.
# NOTE(review): the second call repeats dpi=80 and additionally sets the
# face color, so the first call looks redundant -- confirm before removing.
sc.set_figure_params(dpi=80)
sc.settings.set_figure_params(dpi=80, facecolor='white')

In [None]:
# Absolute root of the analysis repository; every input/output path below is
# built relative to it.
repo_root = '/home/jupyter/mb-ml-data-disk/MAS-seq-analysis'

# inputs
input_prefix = 'M132TS_immune.revised_v2.harmonized'
output_path = 'output/t-cell-vdj-cite-seq'
# Raw long-read AnnData file (.h5ad) read by the next cell.
final_long_adata_raw_h5_path = os.path.join(repo_root, output_path, f'{input_prefix}.long.stringtie.final.raw.h5ad')
# Tab-separated lookup table of transcript equivalence classes.
transcript_eq_class_tsv_path = os.path.join(
    repo_root, 'data/t-cell-vdj/long/quant/revised_v2/equivalence_classes',
    'M132TS_MAS_15x_all_cbc_annotated_array_elements_padded.equivalence_class_lookup.tsv')

# # manual isoform annotation quant
# manual_adata_long_annotation_quants_path = os.path.join(repo_root, 'output/t-cell-vdj-cite-seq/manual_annotations')

In [None]:
# Load the raw long-read AnnData object from the .h5ad path configured above.
adata_long_raw = sc.read(final_long_adata_raw_h5_path)

In [None]:
# Equivalence-class lookup table; the first column of the TSV becomes the
# index, which later cells query by equivalence class id via `teq_pd.loc[...]`.
teq_pd = pd.read_csv(transcript_eq_class_tsv_path, delimiter='\t', index_col=0)

In [None]:
from scipy.sparse import csr_matrix, hstack as sparse_hstack, issparse


def extend_adata(old_adata: sc.AnnData, new_adata: sc.AnnData) -> sc.AnnData:
    """Append the features (columns) of ``new_adata`` to ``old_adata``.

    Rows of ``new_adata`` are re-ordered to follow the barcode order of
    ``old_adata``; rows of ``new_adata`` whose barcode does not occur in
    ``old_adata`` are dropped.

    Args:
        old_adata: object whose ``obs``, ``uns`` and ``obsm`` are carried over
            to the result and whose barcodes (``obs.index``) define the rows.
        new_adata: object providing the additional columns; its ``obs.index``
            must contain every barcode of ``old_adata``.

    Returns:
        A new AnnData whose ``X`` is ``old_adata.X`` followed column-wise by
        the matching rows of ``new_adata.X`` and whose ``var`` is the
        concatenation of both ``var`` tables. ``X`` is a CSR matrix if either
        input matrix is sparse, otherwise a dense array.

    Raises:
        ValueError: if a barcode of ``old_adata`` is missing from
            ``new_adata``.
    """
    old_barcodes = old_adata.obs.index.values
    new_barcodes = new_adata.obs.index.values

    new_barcodes_to_idx_map = {barcode: idx for idx, barcode in enumerate(new_barcodes)}
    missing_barcodes = [
        barcode for barcode in old_barcodes
        if barcode not in new_barcodes_to_idx_map]
    if missing_barcodes:
        raise ValueError(
            f'{len(missing_barcodes)} barcode(s) of old_adata are missing from '
            f'new_adata, e.g. {missing_barcodes[:5]}')

    # row indices into new_adata, listed in the barcode order of old_adata
    kept_new_barcode_indices = [new_barcodes_to_idx_map[barcode] for barcode in old_barcodes]
    old_X = old_adata.X
    new_X_kept = new_adata.X[kept_new_barcode_indices, :]

    merged_var = pd.concat((old_adata.var, new_adata.var))

    # Pick the stacking routine explicitly instead of relying on a bare
    # `except`, which would also hide genuine errors such as shape mismatches.
    if issparse(old_X) or issparse(new_X_kept):
        merged_X = sparse_hstack((csr_matrix(old_X), csr_matrix(new_X_kept))).tocsr()
    else:
        merged_X = np.hstack((old_X, new_X_kept))

    return sc.AnnData(
        X=merged_X,
        obs=old_adata.obs,
        var=merged_var,
        uns=old_adata.uns,
        obsm=old_adata.obsm)

In [None]:
# # extend adata with manual annotations
# for dirname, _, filenames in os.walk(manual_adata_long_annotation_quants_path):
#     for filename in filenames:
#         manual_adata_long_path = os.path.join(dirname, filename)
#         log_info(f'Adding manual isoform annotations from {manual_adata_long_path} ...')
#         if manual_adata_long_path.split('.')[-1] != 'h5ad':
#             continue
#         manual_adata_long = sc.read(manual_adata_long_path)
#         adata_long_raw = extend_adata(adata_long_raw, manual_adata_long)
# adata_long_raw.X = adata_long_raw.X.tocsr()

In [None]:
# Display the `var` rows (features) annotated with the gene name PTPRC.
adata_long_raw.var[adata_long_raw.var['gene_names'] == 'PTPRC']

In [None]:
# Gene of interest. Its gene id is read from the first `var` row carrying this
# gene name (indexing with [0] fails if no row matches).
gene_name = 'PTPRC'
gene_id = adata_long_raw.var[adata_long_raw.var['gene_names'] == gene_name]['gene_ids'].values[0]

# Isoform label -> ENSEMBL transcript ids (written without version suffix)
# that are counted as that isoform.
ensembl_map = {
    'CD45RB': ['ENST00000645247', 'ENST00000643513', 'ENST00000530727', 'ENST00000462363', 'ENST00000427110'],
    'CD45RAB': ['ENST00000529828'],
    'CD45RABC': ['ENST00000442510'],
    'CD45RBC': ['ENST00000391970', 'ENST00000367367'],
    'CD45RO': ['ENST00000367379', 'ENST00000348564'],
}

# Edge types accepted when matching an equivalence class to an annotated
# transcript; the matching loop treats an empty set as "accept every type".
# NOTE(review): these look like gffcompare class codes -- confirm.
# kept_edge_types = {'c', '=', 'k', 'm', 'n', 'j'}
kept_edge_types = {'c', '=', 'k', 'm', 'n', 'j'}

In [None]:
# Display the equivalence-class lookup table.
teq_pd

In [None]:
# Integer transcript equivalence class ids of all features annotated with
# `gene_name`.
teq_list = adata_long_raw.var[
    adata_long_raw.var['gene_names'] == gene_name]['transcript_eq_classes'].values.astype(int)

print(f'Number of transcript equivalence classes associated with {gene_name} gene: {len(teq_list)}')

In [None]:
from collections import defaultdict

# Group the equivalence classes of the gene by the isoform label they map to.
# A class that matches no label, or more than one, goes to 'CD45__UNASSIGNED'.
isoform_name_to_teq_id_map = defaultdict(list)

n_no_hits = 0
n_multi_hits = 0
n_single_hits = 0

# an empty `kept_edge_types` means "do not filter on the edge type"
accept_any_edge_type = len(kept_edge_types) == 0

for teq_id in teq_list:
    # Each comma-separated token is split on ';' below: field 0 is the
    # (versioned) transcript id, field 1 the edge type.
    token_fields = [
        token.split(';')
        for token in teq_pd.loc[teq_id].Transcript_Assignments.split(',')]

    # get rid of the version
    annot_ids = [fields[0].split('.')[0] for fields in token_fields]
    annot_edge_types = [fields[1] for fields in token_fields]

    # transcript ids whose assignment passes the edge type filter
    accepted_annot_ids = {
        annot_id
        for annot_id, annot_edge_type in zip(annot_ids, annot_edge_types)
        if accept_any_edge_type or annot_edge_type in kept_edge_types}

    # determine which ensembl annotated isoforms are involved
    isoform_name_hits = {
        candidate_name
        for candidate_name, ensembl_isoform_id_list in ensembl_map.items()
        if not accepted_annot_ids.isdisjoint(ensembl_isoform_id_list)}

    n_hits = len(isoform_name_hits)
    if n_hits == 1:
        n_single_hits += 1
        (isoform_name,) = isoform_name_hits
        isoform_name_to_teq_id_map[isoform_name].append(teq_id)
    else:
        if n_hits == 0:
            n_no_hits += 1
        else:
            n_multi_hits += 1
        isoform_name_to_teq_id_map['CD45__UNASSIGNED'].append(teq_id)

print(f'Number of {gene_name} equivalence classes with no ENSEMBL hit: {n_no_hits}')
print(f'Number of {gene_name} equivalence classes with multiple ENSEMBL hits: {n_multi_hits}')
print(f'Number of {gene_name} equivalence classes with a single confident ENSEMBL hit: {n_single_hits}')

In [None]:
# List, per isoform label, the raw assignment string of every member class.
for isoform_name in isoform_name_to_teq_id_map:
    print(isoform_name)
    teq_indices = isoform_name_to_teq_id_map[isoform_name]
    for teq_index in teq_indices:
        assignment_desc = teq_pd.loc[teq_index].Transcript_Assignments
        print("    " + assignment_desc)

In [None]:
import scipy.sparse as sp

# aggregate counts: for every isoform label, sum the columns of all member
# equivalence classes into one per-cell total

# The integer view of the equivalence class ids does not depend on the
# isoform, so convert it once instead of once per loop iteration.
var_teq_ids = adata_long_raw.var['transcript_eq_classes'].astype(int)

aggr_X_list = []
for isoform_name, teq_ids in isoform_name_to_teq_id_map.items():
    var_mask = var_teq_ids.isin(teq_ids)
    # sum over the selected columns -> one value per cell
    aggr_X = np.asarray(adata_long_raw[:, var_mask].X.sum(-1)).flatten()
    aggr_X_list.append(aggr_X)

# rows = cells, columns = isoform labels (in the iteration order of the map)
aggr_X_stack = sp.csr_matrix(np.vstack(aggr_X_list).T)
print(f'aggr_X_stack shape: {aggr_X_stack.shape}')

In [None]:
# make an AnnData holding one column per aggregated isoform label
adata_barcodes = adata_long_raw.obs.index.values
barcode_to_adata_row_idx = dict(zip(adata_barcodes, range(len(adata_barcodes))))

prefix = 'aggr__'
adata_isoform_names_list = [
    f'{prefix}{isoform_name}' for isoform_name in isoform_name_to_teq_id_map]
n_aggr_columns = len(adata_isoform_names_list)

# `var` table for the new columns: the prefixed label doubles as equivalence
# class and transcript id; gene-level fields repeat the gene of interest and
# all flags are False.
new_var = pd.DataFrame(
    {
        'transcript_eq_classes': adata_isoform_names_list,
        'gene_eq_classes': [gene_id] * n_aggr_columns,
        'transcript_ids': adata_isoform_names_list,
        'gene_ids': [gene_id] * n_aggr_columns,
        'gene_names': [gene_name] * n_aggr_columns,
        'is_de_novo': [False] * n_aggr_columns,
        'is_gene_id_ambiguous': [False] * n_aggr_columns,
        'is_tcr_overlapping': [False] * n_aggr_columns,
    },
    index=adata_isoform_names_list)

new_adata = sc.AnnData(X=aggr_X_stack, var=new_var)

# rows follow the cell order of the long-read object, so reuse its barcodes
new_adata.obs.index = adata_long_raw.obs.index

In [None]:
# Display a summary of the aggregated-isoform AnnData object.
new_adata

In [None]:
# Append the aggregated isoform columns to the long-read object; `obs`, `uns`
# and `obsm` of the result come from `adata_long_raw`.
ext_adata_long_raw = extend_adata(adata_long_raw, new_adata)

In [None]:
# Display the feature table of the extended object (the aggregated columns
# are appended after the original ones).
ext_adata_long_raw.var

In [None]:
# ... or select isoforms manually
# Feature names (index of `ext_adata_long_raw.var`) to plot, one UMAP each.
transcript_eq_classes = [
    'aggr__CD45RO',
    'aggr__CD45RBC',
    'aggr__CD45RAB',
    'aggr__CD45RABC',
    'aggr__CD45RB',
    'aggr__CD45__UNASSIGNED',
]

# Figure titles / file name suffixes, paired by position with
# `transcript_eq_classes` above (both lists must stay the same length and
# in the same order).
transcript_names_in_fig = [
    'CD45RO',
    'CD45RBC',
    'CD45RAB',
    'CD45RABC',
    'CD45RB',
    'Unassigned'
]

# genes to show total expression alongside the isoforms
# NOTE(review): not referenced by any later cell in this notebook -- confirm
# whether it is still needed.
gene_names = [
    'HNRNPLL',
]

In [None]:
def plot_embedding_leiden(
        adata: sc.AnnData,
        embedding_key: str,
        leiden_key: str,
        markersize=2,
        alpha=0.75,
        xlabel='UMAP1',
        ylabel='UMAP2',
        label_kwargs=None,
        x_offset=None,
        y_offset=None,
        fig=None,
        ax=None,
        show_labels=True,
        figsize=(3, 3)):
    """Scatter-plot a 2D embedding with one color per cluster category.

    Args:
        adata: object providing ``adata.obsm[embedding_key]`` (first two
            columns are plotted), the categorical ``adata.obs[leiden_key]``
            and the color list ``adata.uns[f'{leiden_key}_colors']`` (paired
            by position with the categories).
        embedding_key: key of the embedding in ``adata.obsm``.
        leiden_key: key of the cluster assignment in ``adata.obs``.
        markersize: marker size passed to ``ax.scatter``.
        alpha: marker opacity passed to ``ax.scatter``.
        xlabel: x-axis label.
        ylabel: y-axis label.
        label_kwargs: extra keyword arguments for ``ax.text``; ``None`` uses
            a rounded, semi-transparent white box behind each label.
        x_offset: optional mapping category -> horizontal shift of its label.
        y_offset: optional mapping category -> vertical shift of its label.
        fig: existing figure; a new figure/axes pair is created if either
            ``fig`` or ``ax`` is ``None``.
        ax: existing axes to draw on.
        show_labels: whether to write each category name at the mean position
            of its cells.
        figsize: size of the figure created when none is supplied.

    Raises:
        KeyError: if ``adata.uns`` has no ``f'{leiden_key}_colors'`` entry.
    """
    if ax is None or fig is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Defaults are created per call rather than shared across calls.
    if label_kwargs is None:
        label_kwargs = dict(bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    if x_offset is None:
        x_offset = {}
    if y_offset is None:
        y_offset = {}

    leiden_color_key = f'{leiden_key}_colors'
    if leiden_color_key not in adata.uns.keys():
        raise KeyError(f'{leiden_color_key!r} not found in adata.uns')

    cluster_labels = adata.obs[leiden_key]
    categories = cluster_labels.values.categories
    embedding = adata.obsm[embedding_key]

    # colors are matched to categories by position
    leiden_category_to_leiden_color_map = dict(
        zip(categories, adata.uns[leiden_color_key]))
    cell_color_list = [
        leiden_category_to_leiden_color_map.get(label) for label in cluster_labels]

    ax.scatter(
        embedding[:, 0],
        embedding[:, 1],
        color=cell_color_list,
        s=markersize,
        alpha=alpha)

    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if show_labels:
        for leiden_category in categories:
            in_category = (cluster_labels == leiden_category).to_numpy()
            if not in_category.any():
                # no cells of this category (e.g. in a subset): the mean
                # position would be NaN, so there is nothing to label
                continue
            # Offsets are looked up independently so that supplying only one
            # of them for a category still takes effect.
            dx = x_offset.get(leiden_category, 0)
            dy = y_offset.get(leiden_category, 0)
            x_c = np.mean(embedding[in_category, 0] + dx)
            y_c = np.mean(embedding[in_category, 1] + dy)
            ax.text(
                x_c, y_c, leiden_category,
                fontsize=8,
                ha='center',
                **label_kwargs)

In [None]:
def plot_embedding_continuous(
        adata: sc.AnnData,
        embedding_key: str,
        values: np.ndarray,
        cmap=plt.cm.Blues,
        markersize=2,
        alpha=0.75,
        xlabel='UMAP1',
        ylabel='UMAP2',
        fig=None,
        ax=None,
        sort=True,
        figsize=(3, 3),
        **kwargs):
    """Scatter-plot a 2D embedding colored by one continuous value per cell.

    Args:
        adata: object providing ``adata.obsm[embedding_key]``; its first two
            columns are plotted.
        embedding_key: key of the embedding in ``adata.obsm``.
        values: one value per cell, in the row order of the embedding. Any
            array-like (including a pandas Series) is accepted and flattened
            to a 1-D array.
        cmap: colormap passed to ``ax.scatter``.
        markersize: marker size passed to ``ax.scatter``.
        alpha: marker opacity passed to ``ax.scatter``.
        xlabel: x-axis label.
        ylabel: y-axis label.
        fig: existing figure; a new figure/axes pair is created if either
            ``fig`` or ``ax`` is ``None``.
        ax: existing axes to draw on.
        sort: if true, cells are drawn in ascending order of ``values`` so
            that high values are drawn last (on top).
        figsize: size of the figure created when none is supplied.
        **kwargs: forwarded to ``ax.scatter`` (e.g. ``vmin``/``vmax``).

    Returns:
        The object returned by ``ax.scatter`` (usable for a colorbar).
    """
    if ax is None or fig is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Work on a plain 1-D array: `values[order]` below must index by
    # position, which a label-indexed pandas Series does not guarantee.
    values = np.asarray(values).ravel()

    if sort:
        order = np.argsort(values)
    else:
        order = np.arange(len(values))

    embedding = adata.obsm[embedding_key]
    scatter = ax.scatter(
        embedding[order, 0],
        embedding[order, 1],
        c=values[order],
        cmap=cmap,
        s=markersize,
        alpha=alpha,
        **kwargs)

    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    return scatter

In [None]:
import matplotlib

# Color used for the high end of both colormaps below.
highlight = '#ff0000'

# Gray -> gray -> highlight: the lower half of the value range stays gray.
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    'test',
    ['#dcdcdc', '#dcdcdc', highlight])

# Gray -> highlight: a plain two-color ramp over the whole value range.
cmap_1 = matplotlib.colors.LinearSegmentedColormap.from_list(
    'test',
    ['#dcdcdc', highlight])

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# UMAP colored by the total CD45 antibody signal stored in `obs`.
fig, ax = plt.subplots(figsize=(3.3, 3))

# Embedding and values are taken from the same object. The obs column is
# handed over as a plain array because plot_embedding_continuous indexes
# `values` by integer position, which a barcode-indexed Series does not
# reliably support.
scatter = plot_embedding_continuous(
    ext_adata_long_raw,
    'X_umap_SCT_short',
    values=ext_adata_long_raw.obs['CD45_TotalSeqC'].to_numpy(),
    ax=ax,
    cmap=cmap_1,
    fig=fig)

ax.set_title('CD45 Total [ab]')
# narrow colorbar axis attached to the right of the plot
div = make_axes_locatable(ax)
color_axis = div.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(scatter, cax=color_axis)
cbar.ax.tick_params(labelsize=12)

plt.savefig('./output/M132TS__UMAP__AB__CD45TOTAL.pdf', bbox_inches="tight")

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# UMAP colored by the CD45RA antibody signal stored in `obs`.
fig, ax = plt.subplots(figsize=(3.3, 3))

# Embedding and values are taken from the same object. The obs column is
# handed over as a plain array because plot_embedding_continuous indexes
# `values` by integer position, which a barcode-indexed Series does not
# reliably support.
scatter = plot_embedding_continuous(
    ext_adata_long_raw,
    'X_umap_SCT_short',
    values=ext_adata_long_raw.obs['CD45RA_TotalSeqC'].to_numpy(),
    ax=ax,
    cmap=cmap,
    fig=fig)

ax.set_title('CD45RA [ab]')
# narrow colorbar axis attached to the right of the plot
div = make_axes_locatable(ax)
color_axis = div.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(scatter, cax=color_axis)
cbar.ax.tick_params(labelsize=12)

plt.savefig('./output/M132TS__UMAP__AB__CD45RA.pdf', bbox_inches="tight")

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# UMAP colored by the CD45RO antibody signal stored in `obs`.
fig, ax = plt.subplots(figsize=(3.3, 3))

# Embedding and values are taken from the same object. The obs column is
# handed over as a plain array because plot_embedding_continuous indexes
# `values` by integer position, which a barcode-indexed Series does not
# reliably support.
scatter = plot_embedding_continuous(
    ext_adata_long_raw,
    'X_umap_SCT_short',
    values=ext_adata_long_raw.obs['CD45RO_TotalSeqC'].to_numpy(),
    ax=ax,
    cmap=cmap,
    fig=fig)

ax.set_title('CD45RO [ab]')
# narrow colorbar axis attached to the right of the plot
div = make_axes_locatable(ax)
color_axis = div.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(scatter, cax=color_axis)
cbar.ax.tick_params(labelsize=12)

plt.savefig('./output/M132TS__UMAP__AB__CD45RO.pdf', bbox_inches="tight")

In [None]:
# One UMAP per aggregated isoform: all cells in light gray as a backdrop,
# overlaid with the cells that have at least one count for the isoform,
# colored by their cluster.

# The gray backdrop is the same for every isoform, so the (full) copy of the
# AnnData object is made once instead of once per figure.
adata_gray = ext_adata_long_raw.copy()
adata_gray.uns['mehrtash_leiden_colors'] = ['#dcdcdc'] * len(adata_gray.uns['mehrtash_leiden_colors'])

# the two lists are paired by position: feature name <-> figure label
for transcript_eq_class, transcript_label in zip(transcript_eq_classes, transcript_names_in_fig):

    fig, ax = plt.subplots(figsize=(3, 3))

    plot_embedding_leiden(
        adata_gray,
        'X_umap_SCT_short',
        'mehrtash_leiden',
        show_labels=False,
        alpha=0.25,
        ax=ax,
        fig=fig
    )

    # boolean mask over cells: count of this isoform is greater than zero
    expressing_cells = np.asarray(
        ext_adata_long_raw[:, transcript_eq_class].X.todense()).flatten() > 0

    plot_embedding_leiden(
        ext_adata_long_raw[expressing_cells, :],
        'X_umap_SCT_short',
        'mehrtash_leiden',
        ax=ax,
        fig=fig,
        markersize=2,
        show_labels=False)

    ax.set_title(transcript_label)

    plt.savefig(f'./output/M132TS__UMAP__CD45__mRNA__{transcript_label}.pdf', bbox_inches="tight")

In [None]:
# Work on a copy so the raw counts stay untouched; log1p is applied to the
# copy in place. No other normalization step happens in this cell.
adata_long_norm = ext_adata_long_raw.copy()
sc.pp.log1p(adata_long_norm)

In [None]:
# Per-cell sum of the log1p-transformed values over all features annotated
# with the gene name PTPRC.
# NOTE(review): the gene-name mask also matches the appended 'aggr__' columns
# (they carry gene_names == 'PTPRC'), so the same counts appear to enter this
# sum twice -- confirm whether that is intended.
values = np.asarray(
    adata_long_norm[:, adata_long_norm.var['gene_names'] == 'PTPRC'].X.sum(-1)).flatten()

from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots(figsize=(3.3, 3))

scatter = plot_embedding_continuous(
    adata_long_raw,
    'X_umap_SCT_short',
    values=values,
    ax=ax,
    vmin=-0.4,
    vmax=2.5,
    cmap=plt.cm.Purples,
    fig=fig)

# This figure shows mRNA counts (see the output file name), not the antibody
# signal, so it is titled accordingly.
ax.set_title('CD45 Total [mRNA]')
# narrow colorbar axis attached to the right of the plot
div = make_axes_locatable(ax)
color_axis = div.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(scatter, cax=color_axis)
cbar.ax.tick_params(labelsize=12)

plt.savefig('./output/M132TS__UMAP__CD45__mRNA__TOTAL.pdf', bbox_inches="tight")

# Plot interesting genes

In [None]:
# Load the short-read AnnData object (file name suffix `.sct.h5ad`) from the
# same output directory as the long-read data.
final_short_adata_sct_h5_path = os.path.join(repo_root, output_path, f'{input_prefix}.short.stringtie.final.sct.h5ad')
adata_short_sct = sc.read_h5ad(final_short_adata_sct_h5_path)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# UMAP of the short-read object colored by the expression of one gene.
fig, ax = plt.subplots(figsize=(3.3, 3))

# Gene to plot; it is looked up in the `var` index of the short-read object.
gene_name = 'PTPRC'

highlight = '#ff0000'

# Rebinds the notebook-level `cmap` to a two-color gray -> highlight ramp.
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    'test',
    ['#dcdcdc', highlight])

# NOTE(review): `.X.flatten()` presumes X is a dense array here -- confirm.
# vmin/vmax fix the color scale to the range [0, 3].
scatter = plot_embedding_continuous(
    adata_short_sct,
    'X_umap_SCT_short',
    values=adata_short_sct[:, adata_short_sct.var.index == gene_name].X.flatten(),
    ax=ax,
    cmap=cmap,
    vmin=0,
    vmax=3,
    fig=fig)

ax.set_title(gene_name)
# narrow colorbar axis attached to the right of the plot
div = make_axes_locatable(ax)
color_axis = div.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(scatter, cax=color_axis)
cbar.ax.tick_params(labelsize=12)

plt.savefig(f'./output/M132TS__UMAP__GEX__{gene_name}.pdf', bbox_inches="tight")

## Save the extended adata object (with aggregated PTPRC counts)

In [None]:
# Write the extended raw object (original features plus the aggregated
# isoform columns) next to the input file, with an added `.ext` suffix.
final_long_adata_raw_ext_h5_path = os.path.join(repo_root, output_path, f'{input_prefix}.long.stringtie.final.raw.ext.h5ad')
ext_adata_long_raw.write(final_long_adata_raw_ext_h5_path)

In [None]:
# Display the feature table of the object that was just written.
ext_adata_long_raw.var