In [None]:
%pip install scanpy
%pip install anndata
%pip install MuData
%pip install loompy

In [None]:
from mudata import MuData
import numpy as np
from tqdm import tqdm
import pandas as pd
import anndata
import scanpy
import csv
import sys
import loompy
import logging
import os
from matplotlib import pyplot as plt
import plotly.graph_objects as go

  def twobit_to_dna(twobit: int, size: int) -> str:
  def dna_to_twobit(dna: str) -> int:
  def twobit_1hamming(twobit: int, size: int) -> List[int]:


In [None]:
# Silence every warning category for the rest of the session.
# This is deliberately blunt: it keeps cell output readable but also hides deprecation notices.
import warnings
warnings.simplefilter('ignore', Warning)

# Send INFO-level (and above) log records to stdout as bare message text
logging.basicConfig(format='%(message)s', level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)

In [None]:
# function reads the loomfile downloaded from Tapestri portal
def read_tapestri_loom(filename):
    """
    Read data from MissionBio's formatted loom file.

    The loom file stores variants as rows and cells as columns; every matrix is
    transposed so that the returned object is cells x variants.

    Parameters
    ----------
    filename : str
        Path to the loom file (.loom)

    Returns
    -------
    anndata.AnnData
        An anndata object with the following contents:
        adata.X: GATK calls
        adata.layers['e']: reads with evidence of mutation (loom layer 'AD')
        adata.layers['no_e']: reads without evidence of mutation (loom layer 'RO')
        adata.var_names: loom row attribute 'id'
        adata.obs_names: loom column attribute 'barcode', truncated at the first '-'
        adata.varm['amplicon'], adata.varm['chrom'], adata.varm['loc']:
            loom row attributes 'amplicon', 'CHROM' and 'POS'
    """
    # The context manager closes the file handle even if reading or
    # AnnData construction raises part-way through.
    with loompy.connect(filename) as loom_file:
        variant_names = loom_file.ra['id']
        amplicon_names = loom_file.ra['amplicon']
        chromosome = loom_file.ra['CHROM']
        location = loom_file.ra['POS']

        barcodes = [barcode.split('-')[0] for barcode in loom_file.ca['barcode']]

        # Read the main matrix from disk once and reuse it for both data and dtype
        calls = np.transpose(loom_file[:, :])
        adata = anndata.AnnData(calls, dtype=calls.dtype)
        adata.layers['e'] = np.transpose(loom_file.layers['AD'][:, :])
        adata.layers['no_e'] = np.transpose(loom_file.layers['RO'][:, :])

    adata.var_names = variant_names
    adata.obs_names = barcodes
    adata.varm['amplicon'] = amplicon_names
    adata.varm['chrom'] = chromosome
    adata.varm['loc'] = location

    return adata

def read_HyPR_file(HyPR_file, byprobe=False):
    """
    Read a HyPR count table into an AnnData object.

    The file is expected to have one header line (tab-separated column names)
    followed by one whitespace-separated row per cell: the barcode first, then
    integer counts.

    Parameters
    ----------
    HyPR_file : str
        Path to the count table.
    byprobe : bool, default False
        If truthy, keep one variable per header column. Otherwise, column names
        are truncated at the first '_' and columns sharing the resulting name
        are summed together.

    Returns
    -------
    anndata.AnnData
        Cells x variables count matrix with barcodes as obs_names.

    Raises
    ------
    ValueError
        If the header cannot be aligned with the count columns.
    """
    with open(HyPR_file, 'r') as f:
        header = f.readline().strip().split('\t')

    # ndmin=2 keeps a single-cell file two-dimensional so the column slicing below still works
    table = np.loadtxt(HyPR_file, dtype=str, skiprows=1, ndmin=2)
    barcodes = table[:, 0]
    X = np.asarray(table[:, 1:], dtype=int)

    # The header may or may not carry a label for the barcode column; drop it when present
    var_names = np.array(header, dtype=str)
    if np.size(var_names) != X.shape[1]:
        var_names = var_names[1:]
    if np.size(var_names) != X.shape[1]:
        raise ValueError(
            f"Header of {HyPR_file} has {len(header)} fields but the table has "
            f"{X.shape[1]} count columns"
        )

    if byprobe:
        adata_hypr = anndata.AnnData(X)
        adata_hypr.obs_names = barcodes
        adata_hypr.var_names = var_names
    else:
        probe_counts = pd.DataFrame(X, columns=var_names)
        probe_counts.columns = probe_counts.columns.str.split('_').str[0]
        # Sum columns that share a name; transposing avoids the deprecated groupby(axis=1)
        collapsed_df = probe_counts.T.groupby(level=0).sum().T

        adata_hypr = anndata.AnnData(collapsed_df.to_numpy())
        adata_hypr.obs_names = barcodes
        adata_hypr.var_names = collapsed_df.columns.to_numpy()

    return adata_hypr

# Function for finding the intersecting barcodes between the modalities
def find_intersecting(mdata):
    """
    Find intersecting barcodes and add 'intersecting' to obs with boolean values.

    Each of the two modalities gets a boolean obs column 'intersecting' that is
    True for cells whose obs_name also occurs in the other modality. The object
    is modified in place; nothing is returned.

    Parameters
    ----------
    mdata : object
        The mudata object. Must expose exactly two modalities through ``mdata.mod``.

    Raises
    ------
    AssertionError
        If the length of mod_names is not equal to 2.
    """

    mod_names = list(mdata.mod.keys())
    assert len(mod_names) == 2, 'Function not implemented for mod_names with length different from 2'

    first, second = mdata.mod[mod_names[0]], mdata.mod[mod_names[1]]

    # idx_1 / idx_2 are integer positions of the shared barcodes within each modality
    _, idx_1, idx_2 = np.intersect1d(first.obs_names, second.obs_names, return_indices=True)

    for adata, shared_positions in ((first, idx_1), (second, idx_2)):
        # Build the full mask first and assign the column once: chained
        # `obs[col].iloc[...] = True` does not write through under pandas copy-on-write.
        is_shared = np.zeros(len(adata.obs_names), dtype=bool)
        is_shared[shared_positions] = True
        adata.obs['intersecting'] = is_shared

    logger.info(f"Found {idx_1.size} intersecting barcodes")


def _set_obs_label(adata, obs_key, positions, label):
    """Write ``label`` into ``adata.obs[obs_key]`` at the given integer row positions."""
    rows = np.asarray(positions, dtype=int).ravel()
    # A single .iloc write avoids chained assignment and label-vs-position ambiguity
    adata.obs.iloc[rows, adata.obs.columns.get_loc(obs_key)] = label


def _label_matching_phenotype(genotype, phenotype, obs_key, positions, label):
    """Label phenotype cells whose barcode matches the genotype cells at ``positions``; return the match count."""
    rows = np.asarray(positions, dtype=int).ravel()
    _, _, idx_phenotype = np.intersect1d(genotype.obs_names[rows], phenotype.obs_names,
                                         return_indices=True)
    _set_obs_label(phenotype, obs_key, idx_phenotype, label)
    return idx_phenotype.size


def annotate_genotype(mdata, variants, window=4, obs_key='mutant_type',
                      genotype_key='tapestri', phenotype_key='HyPR',
                      ignore_bystanders=False, write_all_as=False,
                      ignore_mixed=False, only_type=None):
    """
    Label cells in both modalities according to their genotype calls.

    Labels are written in place to ``obs[obs_key]`` of the genotype and phenotype
    modalities (initialised to 'unannotated' when the column is missing). Phenotype
    cells are matched to genotype cells by obs_name. Variants are processed in
    order, so later variants overwrite labels written by earlier ones.

    Parameters
    ----------
    mdata : object
        The mudata object holding both modalities in ``mdata.mod``.
    variants : list of str
        Variant names to look up in the genotype ``var_names``. Names that are
        not found are logged and skipped.
    window : int, default 4
        A variant on the same chromosome counts as "nearby" when its position is
        strictly within ``window`` of the target position. Only used when
        ``ignore_bystanders`` is falsy; requires ``varm['chrom']`` and ``varm['loc']``.
    obs_key : str
        Name of the obs column that receives the labels.
    genotype_key, phenotype_key : str
        Modality names inside ``mdata.mod``.
    ignore_bystanders : bool, default False
        If truthy, cells are labelled '<variant>_het' / '<variant>_hom' without
        looking at nearby variants. Otherwise cells are labelled
        '<variant>_{het,hom}_pure' when the target is the only mutated call in
        the window and '<variant>_{het,hom}_bystander' otherwise.
    write_all_as : str or False, default False
        Only used with ``ignore_bystanders``. When set, matching *phenotype* cells
        get this label instead of the per-variant one, and the genotype modality
        is left unlabelled (as in the original implementation).
    ignore_mixed : bool, default False
        If falsy, cells carrying an edit (call of 1 or 2) at two or more of the
        found ``variants`` are relabelled 'mixed_mutant' at the end.
    only_type : {'het', 'hom', None}
        Restrict annotation to heterozygous (call == 1) or homozygous (call == 2) cells.

    Returns
    -------
    None
    """
    genotype = mdata.mod[genotype_key]
    phenotype = mdata.mod[phenotype_key]

    # Initialise the label column up front so the function also behaves when no
    # variant is found. Mirrors the original rule: a missing phenotype column
    # resets both modalities.
    if obs_key not in phenotype.obs:
        phenotype.obs[obs_key] = 'unannotated'
        genotype.obs[obs_key] = 'unannotated'
    elif obs_key not in genotype.obs:
        genotype.obs[obs_key] = 'unannotated'

    found_idxs = []  # column positions of the variants present in the genotype modality

    for variant in variants:
        try:
            idx = genotype.var_names.get_loc(variant)
        except KeyError:
            logger.info(str(variant) + ' was not found in the genotype modality anndata')
            continue
        found_idxs.append(idx)

        # identify which cells have the variant (call 1 = het, call 2 = hom)
        calls = genotype.X[:, idx]
        het_idx = None if only_type == 'hom' else np.flatnonzero(calls == 1)
        hom_idx = None if only_type == 'het' else np.flatnonzero(calls == 2)

        if ignore_bystanders:
            for zygosity, cell_idx in (('het', het_idx), ('hom', hom_idx)):
                if cell_idx is None:
                    continue
                # `== False` (not `is False`) is kept on purpose so 0 behaves like False, as before
                if write_all_as == False:
                    label = str(variant) + '_' + zygosity
                    _label_matching_phenotype(genotype, phenotype, obs_key, cell_idx, label)
                    _set_obs_label(genotype, obs_key, cell_idx, label)
                else:
                    _label_matching_phenotype(genotype, phenotype, obs_key, cell_idx, write_all_as)
            continue

        # identify indices of nearby SNPs that are within the bystander window
        chromosome, location = genotype.varm['chrom'][idx], genotype.varm['loc'][idx]
        locs = np.asarray(genotype.varm['loc'], dtype=int)
        same_chrom = np.asarray(genotype.varm['chrom']) == chromosome
        in_window = np.logical_and(locs < int(location + window), locs > int(location - window))
        nearby = np.flatnonzero(np.logical_and(same_chrom, in_window))

        # Fancy indexing returns a copy, so the in-place edits below do not touch genotype.X
        mut_vals = genotype.X[:, nearby]
        mut_vals[mut_vals == 3] = 0  # calls of 3 do not count as edits (presumably "missing" - confirm)
        mut_vals[mut_vals > 1] = 1   # het and hom both count as one edited site
        sum_muts = np.sum(mut_vals, axis=1)

        only_target = np.flatnonzero(sum_muts == 1)
        not_only_target = np.flatnonzero(sum_muts != 1)

        for zygosity, description, cell_idx in (('het', 'heterozygous', het_idx),
                                                ('hom', 'homozygous', hom_idx)):
            if cell_idx is None:
                continue

            pure_idx = np.intersect1d(only_target, cell_idx)
            bystander_idx = np.intersect1d(not_only_target, cell_idx)
            pure_label = str(variant) + '_' + zygosity + '_pure'
            bystander_label = str(variant) + '_' + zygosity + '_bystander'

            _set_obs_label(genotype, obs_key, pure_idx, pure_label)
            _set_obs_label(genotype, obs_key, bystander_idx, bystander_label)

            n_bystander = _label_matching_phenotype(genotype, phenotype, obs_key,
                                                    bystander_idx, bystander_label)
            logger.info(f"Annotated {n_bystander} {description} mutants with bystanders")

            n_pure = _label_matching_phenotype(genotype, phenotype, obs_key, pure_idx, pure_label)
            logger.info(f"Annotated {n_pure} pure {description} mutants")

    if not ignore_mixed:
        # find the cells that have mixed edits of the intended variants
        if not found_idxs:
            # None of the variants exist in this sample, so there is nothing to mix
            return
        idxs = np.asarray(found_idxs, dtype=int).flatten()
        has_edit = np.isin(genotype.X[:, idxs], (1, 2))
        mixed_idxs = np.flatnonzero(np.sum(has_edit, axis=1) >= 2)

        _set_obs_label(genotype, obs_key, mixed_idxs, 'mixed_mutant')
        _label_matching_phenotype(genotype, phenotype, obs_key, mixed_idxs, 'mixed_mutant')

In [None]:
import seaborn as sns
def barcode_rank_plot(adata, minimum=0, xmax=None):
  """
  Draw a barcode rank plot: per-cell total counts sorted in descending order.

  Parameters
  ----------
  adata : anndata.AnnData
      Object whose ``X`` holds the counts (cells x features).
  minimum : number, default 0
      Only cells with a total count strictly greater than this are plotted.
  xmax : number, optional
      If given, the x axis is limited to ``(0, xmax)``.

  Returns
  -------
  None
      The figure is shown with ``plt.show()``.
  """
  # Sum the UMIs for each cell. np.asarray(...).ravel() yields a flat vector whether
  # .sum(axis=1) returned a 1-D array or a 2-D (n, 1) result (as sparse matrices do).
  cell_umi_counts_all = np.asarray(adata.X.sum(axis=1)).ravel()
  cell_umi_counts = cell_umi_counts_all[cell_umi_counts_all > minimum]

  # Sort the UMI counts in descending order for the rank plot
  sorted_umi_counts = np.sort(cell_umi_counts)[::-1]

  fig, ax = plt.subplots(figsize=(5, 4), dpi=150)
  sns.lineplot(x=np.arange(1, sorted_umi_counts.size + 1), y=sorted_umi_counts, ax=ax)
  ax.set_xlabel('Barcode Rank')
  ax.set_ylabel('UMI Count')
  ax.set_title('Cell Barcode Rank Plot')
  ax.set_yscale('log')  # Log scale for better visualization
  if xmax is not None:
    ax.set_xlim(0, xmax)
  plt.show()

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive so files stored there can be read.
# This cell only works inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# define where all the files are on the drive
data_dir = '/path/to/tnrc18/data'

# One [tapestri loom, HyPR table] pair per sample, relative to data_dir
file_locations = [['tapestri/HSC_Neu_D1_Day9.cells.loom', 'HyPR/Neu_D1_Day9_HyPR.txt'],
                  ['tapestri/HSC_Neu_D2_Day9.cells.loom', 'HyPR/Neu_D2_Day9_HyPR.txt'],
                  ['tapestri/HSC_Neu_D1_Day17.cells.loom', 'HyPR/Neu_D1_Day17_HyPR.txt'],
                  ['tapestri/HSC_Neu_D2_Day17.cells.loom', 'HyPR/Neu_D2_Day17_HyPR.txt']]

# Sample names, in the same order as file_locations
sample_names = ['HSC_Neu_D1_Day9', 'HSC_Neu_D2_Day9', 'HSC_Neu_D1_Day17', 'HSC_Neu_D2_Day17']
assert len(sample_names) == len(file_locations), 'expected one sample name per file pair'

# check to make sure all those files exist; any missing path is printed.
# os.path.join inserts the separator that plain `data_dir + relative_path` would drop.
for tapestri_rel, hypr_rel in tqdm(file_locations):
  for path in (os.path.join(data_dir, tapestri_rel), os.path.join(data_dir, hypr_rel)):
    if not os.path.exists(path):
      print(path)

100%|██████████| 6/6 [00:01<00:00,  5.14it/s]


In [None]:
# Load both modalities of every sample and pair them in one MuData object per sample.
# os.path.join inserts the separator that plain `data_dir + relative_path` would drop.
mdata_list = []
for tapestri_rel, hypr_rel in tqdm(file_locations):
    adata_hypr = read_HyPR_file(os.path.join(data_dir, hypr_rel), byprobe=False)
    adata_tapestri = read_tapestri_loom(os.path.join(data_dir, tapestri_rel))
    mdata_list.append(MuData({'HyPR': adata_hypr, 'tapestri': adata_tapestri}))

100%|██████████| 6/6 [03:29<00:00, 34.86s/it]


In [None]:
# One barcode rank plot per sample (HyPR modality), preceded by the sample name
sns.set_theme()
for i in range(len(sample_names)):
  sample_name = sample_names[i]
  print(sample_name)
  hypr_counts = mdata_list[i].mod['HyPR']
  barcode_rank_plot(hypr_counts)

In [None]:
# Minimum total HyPR counts a cell needs to be kept.
# Samples at positions 2 and 3 of sample_names (the Day17 samples) use a lower cutoff.
# NOTE: filter_cells modifies the objects in place, so re-running this cell filters again.
default_min_counts = 1000
min_counts_override = {2: 316, 3: 316}

for i, sample_name in enumerate(sample_names):
  threshold = min_counts_override.get(i, default_min_counts)
  scanpy.pp.filter_cells(mdata_list[i].mod['HyPR'], min_counts=threshold)

In [None]:
# Flag barcodes shared by both modalities, then keep only those cells in each modality;
# cells seen by a single modality are not needed downstream.
for mdata in mdata_list:
  find_intersecting(mdata)
  for modality_name in ('tapestri', 'HyPR'):
    modality = mdata.mod[modality_name]
    keep = modality.obs['intersecting'] == True
    mdata.mod[modality_name] = modality[keep].copy()

In [None]:
# identify the AAVS (control) edits by chromosome, position window and edit type.
# Variant names are taken from the last sample explicitly, instead of relying on the
# `mdata` loop variable left over from an earlier cell.
# NOTE(review): variants that only occur in other samples are not considered here - confirm this is intended.
variant_names = np.asarray(mdata_list[-1].mod['tapestri'].var_names.values)

# Variant names are split on ':' into chromosome, position and edit type (e.g. 'chr7:5397122:C/T')
chrom, loc, edit_type = [], [], []
for name in variant_names:
    fields = name.split(':')
    chrom.append(fields[0])
    loc.append(int(fields[1]))
    edit_type.append(fields[2])

# (positions, accepted edit types); range() is half-open, so the upper bound is excluded
control_chrom = 'chr19'
control_windows = (
    (range(55115745, 55115764), ("A/G", "T/C")),
    (range(55115752, 55115771), ("C/T", "G/A")),
)

# The two windows accept disjoint edit types, so a variant is selected at most once
control_editing = [
    name
    for name, name_chrom, name_loc, name_edit in zip(variant_names, chrom, loc, edit_type)
    if name_chrom == control_chrom
    and any(name_loc in positions and name_edit in edits for positions, edits in control_windows)
]

In [None]:
# Target variant(s) to genotype
variants = ['chr7:5397122:C/T']

for mdata in mdata_list:
    # Pass 1: label phenotype cells carrying any control edit as 'AAVS' (no bystander or mixed checks)
    annotate_genotype(mdata, control_editing, write_all_as='AAVS',
                      ignore_bystanders=True, ignore_mixed=True)
    # Pass 2: label the target variant, separating pure from bystander edits within 10 positions
    annotate_genotype(mdata, variants, window=10,
                      ignore_bystanders=False, ignore_mixed=False)

In [None]:
# Collect the per-sample objects of each modality
adata_hypr_list = [mdata.mod['HyPR'] for mdata in mdata_list]
adata_tapestri_list = [mdata.mod['tapestri'] for mdata in mdata_list]

# Tag every cell of both modalities with the name of the sample it came from
for modality_list in (adata_hypr_list, adata_tapestri_list):
    for anndata_obj, batch_name in zip(modality_list, sample_names):
        anndata_obj.obs['batch'] = batch_name

# create a concatenated anndata object for the HyPR
concatenated_anndata = anndata.concat(adata_hypr_list)

# 'S' is the last character of the batch name as an int; 'stimulation' is the batch name without it.
# NOTE(review): with names ending in 'Day9' / 'Day17' this gives S = 9 / 7 and
# stimulation = '...Day' / '...Day1' - confirm these columns are what is intended.
batch_labels = concatenated_anndata.obs['batch'].str
concatenated_anndata.obs['S'] = batch_labels[-1].astype(int)
concatenated_anndata.obs['stimulation'] = batch_labels[:-1]

In [None]:
# Keep a copy of the current count matrix in a layer before later cells transform X
concatenated_anndata.layers['raw'] = concatenated_anndata.X.copy()

In [None]:
# Scale every cell to a total of 10,000 counts, then apply log(1 + x)
scanpy.pp.normalize_total(concatenated_anndata, target_sum=1e4)
scanpy.pp.log1p(concatenated_anndata)

In [None]:
# Dimensionality reduction: PCA -> neighbor graph -> UMAP.
# Library defaults are used everywhere except the UMAP min_dist and spread.
scanpy.tl.pca(concatenated_anndata)
scanpy.pp.neighbors(concatenated_anndata)
scanpy.tl.umap(concatenated_anndata, min_dist=0.5, spread=3)

In [None]:
# Check the UMAP, coloured by sample; figure settings apply to this plot only
umap_rc = {"figure.figsize": (4, 3), "figure.dpi": 400}
with plt.rc_context(umap_rc):
  scanpy.pl.umap(concatenated_anndata, color='batch', show=False)
  ax = plt.gca()
  # Hide the axes frame
  for side in ax.spines:
      ax.spines[side].set_visible(False)
  plt.show()

In [None]:
# Save the processed HyPR object as an .h5ad for easy use in Python.
# NOTE: output_path is a placeholder - point it at the destination .h5ad file before running.
output_path = "where/to/write/h5ad"
concatenated_anndata.write(output_path)