## Harmonize long-read and short-read AnnData objects for M132TS

This notebook harmonizes long-read and short-read AnnData objects (e.g., restricting both to mutual barcodes and mutual genes) and produces several AnnData outputs and figures, including a scrubbed AnnData object for Seurat SCT analysis.

**Inputs and Outputs**
- Inputs:
  - `M132TS_both.h5ad`: short-reads AnnData object (raw counts are taken from its `.raw` slot)
  - `M132TS_MAS_15x_overall_gene_tx_expression_count_matrix_tx_gene_counts_adata.h5ad`: long-reads counts matrix (raw)
- Outputs:
  - harmonized long-reads and short-reads AnnData objects (raw counts, all genes)
  - harmonized long-reads and short-reads AnnData objects (raw counts, all genes, metadata and unstructured data removed [for Seurat SCT analysis])
  - harmonized long-reads and short-reads AnnData objects (raw counts, mutual genes)
  - short vs. long gene expression concordance
  - gene- and transcript-level saturation curves and fits

In [None]:
%matplotlib inline

import matplotlib.pylab as plt

import numpy as np
import pandas as pd
import os
import sys
from time import time
import logging
import pickle
from operator import itemgetter

import scanpy as sc

# Font sizes (points) used by the matplotlib rc settings below.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

# (rc group, keyword arguments) pairs, applied in this order.
for _rc_group, _rc_kwargs in (
        ('font', {'size': SMALL_SIZE}),            # default text size
        ('axes', {'titlesize': SMALL_SIZE}),       # axes title
        ('axes', {'labelsize': MEDIUM_SIZE}),      # x and y labels
        ('xtick', {'labelsize': SMALL_SIZE}),      # x tick labels
        ('ytick', {'labelsize': SMALL_SIZE}),      # y tick labels
        ('legend', {'fontsize': SMALL_SIZE}),      # legend
        ('figure', {'titlesize': BIGGER_SIZE})):   # figure title
    plt.rc(_rc_group, **_rc_kwargs)

# Root logger at INFO. `log_info` is bound to `warning`, not `info` --
# presumably so that messages show up in the notebook output.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_info = logger.warning

import warnings
# NOTE: silences *all* warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

sc.settings.set_figure_params(dpi=80, facecolor='white')

In [None]:
# constants
# ---------------------------------------------------------------------------
# `.var` column holding gene ids, for each AnnData object.
ADATA_SHORT_GENE_IDS_COL = 'gene_ids'
ADATA_LONG_GENE_IDS_COL = 'gene_ids'
# Gene-id prefixes used to count known vs. de novo long-reads genes.
KNOWN_GENE_PREFIX = 'ENSG'
DE_NOVO_GENE_PREFIX = 'MASG'

# traits to propagate from short adata to long adata:
# `.obs` columns ...
adata_short_obs_col_propagate = [
    'CD45_TotalSeqC',
    'CD45R_B220_TotalSeqC',
    'CD45RA_TotalSeqC',
    'CD45RO_TotalSeqC',
    'leiden_crude',
]
# ... and `.obsm` keys
adata_short_obsm_col_propagate = ['X_pca', 'X_umap']

# input
# ---------------------------------------------------------------------------
repo_root = '/home/jupyter/mb-ml-data-disk/MAS-seq-analysis'
short_h5_path = os.path.join(repo_root, 'output/t-cell-vdj-cite-seq/M132TS_both.h5ad')
long_h5_path = os.path.join(
    repo_root,
    'data/t-cell-vdj/long/quant/revised_v2/M132TS_MAS_15x_overall_gene_tx_expression_count_matrix_tx_gene_counts_adata.h5ad')

# output
# ---------------------------------------------------------------------------
output_prefix = 'M132TS_both_new_pipeline__revised_v2'
output_path = os.path.join(repo_root, 'output/t-cell-vdj-cite-seq')

# mutual barcodes, all genes
harmonized_long_adata_h5_path = os.path.join(
    output_path, f'{output_prefix}.harmonized.barnyard.long.h5ad')
harmonized_short_adata_h5_path = os.path.join(
    output_path, f'{output_prefix}.harmonized.barnyard.short.h5ad')

# mutual barcodes, mutual genes
harmonized_long_adata_mutual_genes_h5_path = os.path.join(
    output_path, f'{output_prefix}.harmonized.mutual_genes.barnyard.long.h5ad')
harmonized_short_adata_mutual_genes_h5_path = os.path.join(
    output_path, f'{output_prefix}.harmonized.mutual_genes.barnyard.short.h5ad')

In [None]:
# load data
# long reads: make var names unique (in case of duplicates) and drop the
# 'Cell Barcode' obs column
adata_long = sc.read(long_h5_path)
adata_long.var_names_make_unique()
adata_long.obs = adata_long.obs.drop(columns=['Cell Barcode'])

# short reads: keep only the counts stored in the `.raw` slot
adata_short = sc.read(short_h5_path).raw.to_adata()

In [None]:
# metrics container: summary numbers collected by the cells below
metrics = {}

In [None]:
# number of barcodes (rows of X) in each object
n_barcodes_short, n_barcodes_long = adata_short.X.shape[0], adata_long.X.shape[0]
metrics.update(n_barcodes_short=n_barcodes_short, n_barcodes_long=n_barcodes_long)

In [None]:
# Distinct (gene_id, is_de_novo, is_gene_id_ambiguous) triples of the long-reads var table.
adata_gene_info_set = set(zip(
    adata_long.var[ADATA_LONG_GENE_IDS_COL].values,
    adata_long.var['is_de_novo'].values,
    adata_long.var['is_gene_id_ambiguous'].values))

# Distinct long-reads gene ids, counted overall and by id prefix
# (KNOWN_GENE_PREFIX vs. DE_NOVO_GENE_PREFIX).
gene_ids_set = set(map(itemgetter(0), adata_gene_info_set))
n_total_genes_long = len(gene_ids_set)
n_total_genes_short = len(adata_short.var)
n_de_novo_genes_long = sum(gene_id.startswith(DE_NOVO_GENE_PREFIX) for gene_id in gene_ids_set)
n_gencode_genes_long = sum(gene_id.startswith(KNOWN_GENE_PREFIX) for gene_id in gene_ids_set)

metrics['n_total_genes_long'] = n_total_genes_long
metrics['n_total_genes_short'] = n_total_genes_short
# record the per-category counts, not the overall long-reads total
metrics['n_gencode_genes_long'] = n_gencode_genes_long
metrics['n_de_novo_genes_long'] = n_de_novo_genes_long

log_info(f'All short-reads adata genes: {n_total_genes_short}')
log_info(f'All long-reads adata genes: {n_total_genes_long}')
log_info(f'Known long-reads adata genes: {n_gencode_genes_long}')
log_info(f'de novo long-reads adata genes: {n_de_novo_genes_long}')

In [None]:
from collections import Counter

# Distinct (versioned) gene ids present in each object.
adata_short_gene_id_set = set(adata_short.var[ADATA_SHORT_GENE_IDS_COL].values)
adata_long_gene_id_set = set(adata_long.var[ADATA_LONG_GENE_IDS_COL].values)


def drop_version(entry):
    """Drop the '.<version>' suffix of an id starting with 'ENS'; return other ids unchanged."""
    if entry.startswith('ENS'):
        return entry.split('.')[0]
    return entry


# Number of distinct versioned ids that collapse onto each unversioned id.
unversioned_adata_short_gene_id_counter = Counter(map(drop_version, adata_short_gene_id_set))
unversioned_adata_long_gene_id_counter = Counter(map(drop_version, adata_long_gene_id_set))

# Unversioned ids backed by exactly one versioned id.
ver_unambiguous_adata_short_gene_id_list = [
    gene_id for gene_id, n_versions in unversioned_adata_short_gene_id_counter.items()
    if n_versions == 1]
ver_unambiguous_adata_long_gene_id_list = [
    gene_id for gene_id, n_versions in unversioned_adata_long_gene_id_counter.items()
    if n_versions == 1]

# Unversioned ids of the long-reads genes flagged `is_gene_id_ambiguous`.
gene_id_ambiguous_adata_long_unversioned_gene_id_set = {
    drop_version(gene_id)
    for gene_id in adata_long[:, adata_long.var['is_gene_id_ambiguous']].var[ADATA_LONG_GENE_IDS_COL].values}

# Genes used for the mutual-genes outputs: version-unambiguous in both objects
# and not flagged as ambiguous in the long reads.
final_unversioned_unambiguous_mutual_gene_id_set = (
    set(ver_unambiguous_adata_long_gene_id_list)
    & set(ver_unambiguous_adata_short_gene_id_list)
) - gene_id_ambiguous_adata_long_unversioned_gene_id_set

# Record all sizes first, then log them in the same order.
_gene_id_set_sizes = [
    ('n_adata_short_gene_id_set', len(adata_short_gene_id_set)),
    ('n_adata_long_gene_id_set', len(adata_long_gene_id_set)),
    ('n_ver_unambiguous_adata_short', len(ver_unambiguous_adata_short_gene_id_list)),
    ('n_ver_unambiguous_adata_long', len(ver_unambiguous_adata_long_gene_id_list)),
    ('n_gene_id_ambiguous_adata', len(gene_id_ambiguous_adata_long_unversioned_gene_id_set)),
    ('n_final_unversioned_unambiguous_mutual_gene_id_set', len(final_unversioned_unambiguous_mutual_gene_id_set)),
]
metrics.update(_gene_id_set_sizes)
for _metric_name, _metric_value in _gene_id_set_sizes:
    log_info(f'{_metric_name}: {_metric_value}')

In [None]:
def _mutual_var_indices_sorted_by_gene(gene_ids, mutual_gene_ids):
    """Return the var positions whose unversioned gene id is in `mutual_gene_ids`.

    Positions are ordered by unversioned gene id. Positions that share a gene id
    keep their original ascending order, so both objects list genes in the same order.
    """
    keyed_positions = [
        (drop_version(gene_id), var_idx)
        for var_idx, gene_id in enumerate(gene_ids)
        if drop_version(gene_id) in mutual_gene_ids]
    return [var_idx for _, var_idx in sorted(keyed_positions)]


final_adata_short_mutual_keep_var_indices = _mutual_var_indices_sorted_by_gene(
    adata_short.var[ADATA_SHORT_GENE_IDS_COL].values,
    final_unversioned_unambiguous_mutual_gene_id_set)
final_adata_long_mutual_keep_var_indices = _mutual_var_indices_sorted_by_gene(
    adata_long.var[ADATA_LONG_GENE_IDS_COL].values,
    final_unversioned_unambiguous_mutual_gene_id_set)

In [None]:
# subset long adata barcodes to short adata
adata_short_barcodes_set = set(adata_short.obs.index.values)
_long_barcodes = adata_long.obs.index.values

# long-reads rows whose barcode also appears in the short reads
adata_long_keep_indices = [
    row for row, barcode in enumerate(_long_barcodes)
    if barcode in adata_short_barcodes_set]
found_barcodes_set = {_long_barcodes[row] for row in adata_long_keep_indices}
not_found_barcodes_set = adata_short_barcodes_set - found_barcodes_set

if not_found_barcodes_set:
    log_info(f'{len(not_found_barcodes_set)} out of {len(adata_short_barcodes_set)} could not be found in the long reads adata!')
else:
    log_info(f'All {len(adata_short_barcodes_set)} barcodes could be found in the long reads adata.')

# Both objects are re-ordered by the same sorted barcode list.
found_barcodes_list = sorted(found_barcodes_set)

adata_short_barcode_index_map = {
    barcode: row for row, barcode in enumerate(adata_short.obs.index.values)}
final_adata_short_keep_obs_indices = [
    adata_short_barcode_index_map[barcode] for barcode in found_barcodes_list]

adata_long_barcode_index_map = {
    barcode: row for row, barcode in enumerate(_long_barcodes)}
final_adata_long_keep_obs_indices = [
    adata_long_barcode_index_map[barcode] for barcode in found_barcodes_list]
# long-reads rows with no short-reads counterpart ("empty" barcodes)
final_adata_long_not_keep_obs_indices = sorted(
    set(adata_long_barcode_index_map.values()) - set(final_adata_long_keep_obs_indices))

In [None]:
# finally, slice: rows -> barcode subsets, columns -> mutual genes
# short reads
adata_short_mutual_barcodes = adata_short[final_adata_short_keep_obs_indices, :]
adata_short_mutual_barcodes_genes = adata_short_mutual_barcodes[:, final_adata_short_mutual_keep_var_indices]

# long reads, barcodes shared with the short reads
adata_long_mutual_barcodes = adata_long[final_adata_long_keep_obs_indices, :]
adata_long_mutual_barcodes_genes = adata_long_mutual_barcodes[:, final_adata_long_mutual_keep_var_indices]

# long reads, barcodes absent from the short reads ("empty")
adata_long_empty_barcodes = adata_long[final_adata_long_not_keep_obs_indices, :]
adata_long_empty_barcodes_mutual_genes = adata_long_empty_barcodes[:, final_adata_long_mutual_keep_var_indices]

# long reads, all barcodes, mutual genes only
adata_long_mutual_genes = adata_long[:, final_adata_long_mutual_keep_var_indices]

In [None]:
n_mutual_barcodes = len(final_adata_long_keep_obs_indices)
metrics['n_mutual_barcodes'] = n_mutual_barcodes
metrics['pct_mutual_barcodes'] = 100. * (n_mutual_barcodes / n_barcodes_short)

# UMI statistics: mean / median total counts per barcode for each subset.
# Keys are 'mean_umi_per_<suffix>' and 'median_umi_per_<suffix>'.
for _suffix, _subset in (
        ('barcode_short', adata_short),
        ('barcode_long', adata_long),
        ('mutual_barcode_short', adata_short_mutual_barcodes),
        ('mutual_barcode_long', adata_long_mutual_barcodes),
        ('empty_barcode_long', adata_long_empty_barcodes)):
    _umis_per_barcode = np.asarray(_subset.X.sum(-1)).flatten()
    metrics[f'mean_umi_per_{_suffix}'] = np.mean(_umis_per_barcode)
    metrics[f'median_umi_per_{_suffix}'] = np.median(_umis_per_barcode)

# Share of long-reads UMIs that fall in barcodes absent from the short reads,
# over all genes and over mutual genes only.
metrics['pct_umis_in_empty_barcodes_long'] = (
    100. * adata_long_empty_barcodes.X.sum() / adata_long.X.sum())
metrics['pct_umis_in_empty_barcodes_mutual_genes_long'] = (
    100. * adata_long_empty_barcodes_mutual_genes.X.sum() / adata_long_mutual_genes.X.sum())
metrics['pct_umis_in_empty_barcodes_mutual_genes_long'] = 100. * adata_long_empty_barcodes_mutual_genes.X.sum() / adata_long_mutual_genes.X.sum()

In [None]:
# Copy selected short-reads annotations onto the long-reads objects. Both mutual-barcode
# objects were sliced in `found_barcodes_list` order, so positional (`.values`) copies
# line up barcode-for-barcode.
# Best-effort: a missing key (KeyError) or a length/shape mismatch (ValueError) is logged
# and skipped. Only these errors are caught, so unexpected failures and
# KeyboardInterrupt still surface.
for col in adata_short_obs_col_propagate:
    try:
        adata_long_mutual_barcodes.obs[col] = adata_short_mutual_barcodes.obs[col].values.copy()
        adata_long_mutual_barcodes_genes.obs[col] = adata_short_mutual_barcodes_genes.obs[col].values.copy()
    except (KeyError, ValueError):
        log_info(f'WARNING: Could not propagate {col}!')

for col in adata_short_obsm_col_propagate:
    try:
        adata_long_mutual_barcodes.obsm[col] = adata_short_mutual_barcodes.obsm[col].copy()
        adata_long_mutual_barcodes_genes.obsm[col] = adata_short_mutual_barcodes_genes.obsm[col].copy()
    except (KeyError, ValueError):
        log_info(f'WARNING: Could not propagate {col}!')

In [None]:
# save: (object, destination) pairs -- all genes first, then mutual genes
for _harmonized_adata, _h5_path in (
        (adata_long_mutual_barcodes, harmonized_long_adata_h5_path),
        (adata_short_mutual_barcodes, harmonized_short_adata_h5_path),
        (adata_long_mutual_barcodes_genes, harmonized_long_adata_mutual_genes_h5_path),
        (adata_short_mutual_barcodes_genes, harmonized_short_adata_mutual_genes_h5_path)):
    _harmonized_adata.write(_h5_path)

## Concordance between short and long adata total GEX over mutual genes

In [None]:
# From here on, `adata_short` / `adata_long` are the saved mutual-genes outputs.
adata_short, adata_long = (
    sc.read(harmonized_short_adata_mutual_genes_h5_path),
    sc.read(harmonized_long_adata_mutual_genes_h5_path))

In [None]:
# highest expressed genes (short reads), saved as PNG
_short_top_genes_png = os.path.join(
    output_path, f'{output_prefix}.highly_expressed_genes.short.png')
with plt.rc_context():
    sc.pl.highest_expr_genes(adata_short, n_top=50, show=False)
    plt.savefig(_short_top_genes_png, bbox_inches='tight')

In [None]:
# Relabel long-reads var rows as "<gene name> (<transcript eq. class>)" and use
# that label as the var index.
new_adata_long_index = [
    f'{gene_name} ({teq})'
    for gene_name, teq in zip(adata_long.var['gene_names'], adata_long.var['transcript_eq_classes'])]
adata_long.var['genes_names_teq'] = new_adata_long_index
adata_long.var.set_index('genes_names_teq', drop=True, inplace=True)

In [None]:
# highest expressed isoforms (long reads), saved as PNG
_long_top_genes_png = os.path.join(
    output_path, f'{output_prefix}.highly_expressed_genes.long.png')
with plt.rc_context():
    sc.pl.highest_expr_genes(adata_long, n_top=50, show=False)
    plt.savefig(_long_top_genes_png, bbox_inches='tight')

In [None]:
# Column sums of X: total counts per long-reads var column and per short-reads gene.
total_tx_expr_long = np.asarray(adata_long.X.sum(axis=0)).ravel()
total_gene_expr_short = np.asarray(adata_short.X.sum(axis=0)).ravel()

In [None]:
# Unversioned gene id of every var column, in var order.
short_gene_ids = [drop_version(gene_id) for gene_id in adata_short.var[ADATA_SHORT_GENE_IDS_COL].values]
long_gene_ids = [drop_version(gene_id) for gene_id in adata_long.var[ADATA_LONG_GENE_IDS_COL].values]

In [None]:
from collections import defaultdict
from itertools import groupby
from operator import itemgetter

# Map each unversioned gene id to the positions of ALL its long-reads var columns.
# A dict of lists is used instead of `groupby`, which merges only *consecutive*
# equal keys. With `groupby`, a gene whose columns were not contiguous would keep
# only its last run; the old code silently relied on the var ordering written
# earlier in this notebook.
gene_id_to_tx_indices_map = defaultdict(list)
for tx_idx, gene_id in enumerate(long_gene_ids):
    gene_id_to_tx_indices_map[gene_id].append(tx_idx)
# Convert back to a plain dict so that an unknown gene id raises KeyError
# instead of silently summing to 0.
gene_id_to_tx_indices_map = dict(gene_id_to_tx_indices_map)

# Long-reads total per gene (summed over its columns), in short-reads gene order,
# so it aligns element-wise with `total_gene_expr_short`.
total_gene_expr_long = np.asarray([
    np.sum(total_tx_expr_long[gene_id_to_tx_indices_map[gene_id]])
    for gene_id in short_gene_ids])

In [None]:
def _to_tpm(counts):
    """Scale `counts` so that they sum to one million.

    NB: there is no length normalization, so this is counts-per-million
    despite the `_tpm` naming.
    """
    return 1_000_000 * counts / np.sum(counts)


total_gene_expr_short_tpm = _to_tpm(total_gene_expr_short)
total_gene_expr_long_tpm = _to_tpm(total_gene_expr_long)

In [None]:
import matplotlib as mpl

# Restore matplotlib's default rcParams before the concordance plot. This undoes the
# font sizes set at the top of the notebook and, presumably, the params set by
# `sc.settings.set_figure_params`.
# `rcdefaults()` skips style-blacklisted keys such as `backend`, whereas
# `rcParams.update(rcParamsDefault)` also overwrites the backend and may interfere
# with the inline backend.
mpl.rcdefaults()

In [None]:
import matplotlib.ticker as tck
from sklearn.metrics import r2_score

# Shared log-log range and major ticks for both axes.
_concordance_lims = [1e-1, 1e5]
_concordance_ticks = [1e-1, 1e1, 1e3, 1e5]

fig, ax = plt.subplots(figsize=(4, 4))

# y = x reference line, then one point per mutual gene
ax.plot(_concordance_lims, _concordance_lims, '--', lw=1, color='black')
ax.scatter(total_gene_expr_short_tpm, total_gene_expr_long_tpm, s=1, alpha=0.2, color='gray')

# R^2 of long-reads vs. short-reads values on a log1p scale, i.e. agreement
# with the y = x line (long values are used directly as the "predictions").
r2 = r2_score(np.log1p(total_gene_expr_short_tpm), np.log1p(total_gene_expr_long_tpm))
ax.text(0.15, 3e4, f'$R^2$ = {r2:.2f}', fontsize=10)

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xticks(_concordance_ticks)
ax.set_yticks(_concordance_ticks)
for _axis in (ax.xaxis, ax.yaxis):
    _axis.set_minor_locator(tck.AutoMinorLocator())
ax.set_xlim(tuple(_concordance_lims))
ax.set_ylim(tuple(_concordance_lims))
ax.set_aspect('equal')
ax.set_xlabel('NGS total expression (TPM)')
ax.set_ylabel('MAS-seq total expression (TPM)')

fig.tight_layout()

fig.savefig(
    os.path.join(output_path, f'{output_prefix}.gex.short.long.concordance.png'),
    bbox_inches='tight')