# adata

> Helpers for annotating, quality-controlling, filtering, normalizing and embedding `AnnData` objects (PCA, PHATE, MAGIC).

In [None]:
#| default_exp adata

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
# NOTE: Python 3.10 removed the ABC aliases (``collections.Iterable`` and
# friends) from the top-level ``collections`` namespace; they now live only in
# ``collections.abc``. Dependencies that still look them up the old way
# (scanpy uses ``Iterable``; hyper needs all four names below) would fail, so
# re-export them onto ``collections`` for forward compatibility.
import collections.abc

for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name

In [None]:
#| export
from spot_mark_gene.types import (
    AnnData, AnnDatas, Graph, SeriesLike,
    VAR_HUMAN_TF, VAR_MOUSE_TF,
    VAR_HUMAN_ENSEMBLE_ID, VAR_MOUSE_ENSEMBLE_ID,
    LAYER_PRENORM, LAYER_DETECTED,
    LAYER_SCALED_NORMALIZED, EMB_MAGIC,
    EMB_PCA, EMB_PCA_HVG,
    EMB_PHATE, EMB_PHATE_HVG,
    CUTOFF_KIND, CUTOFF_SHORTHAND_TO_OBS_KEYS,
    CutoffSpecification, CutoffSpecifications,
    VAR_GENE_SYMBOL, VAR_GENE_IDS,
    OBS_DOUBLET_SCORES, OBS_PREDICTED_DOUBLETS,
    VAR_MITO
)

In [None]:
#| export
import os
import copy

from typing import TypeAlias, List, Sequence, Tuple

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet as scr
import scipy
import graphtools as gt
import phate
import magic

In [None]:
#| export
def add_gene_symbols_to_adata(adata:AnnData) -> AnnData:
    '''
    Record the (de-duplicated) var names as gene symbols.

    ``adata.var_names`` is first made unique in place, then copied into
    ``adata.var[VAR_GENE_SYMBOL]``.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    adata.var_names_make_unique()
    symbols = adata.var_names
    adata.var[VAR_GENE_SYMBOL] = symbols
    return adata

def add_gene_ids_to_adata(adata:AnnData) -> AnnData:
    '''
    Re-index genes by the IDs stored in ``adata.var[VAR_GENE_IDS]``.

    Returns:
        The same ``adata`` object with ``var_names`` replaced in place.
    '''
    gene_ids = adata.var[VAR_GENE_IDS]
    adata.var_names = gene_ids
    return adata

def remove_mitochondrial_genes(adata:AnnData) -> AnnData:
    '''
    Drop genes flagged in ``adata.var[VAR_MITO]``.

    NOTE:
        - assumes ``adata.var[VAR_MITO]`` is a boolean column -- TODO confirm.
        - returns a gene-subset of ``adata`` (AnnData slicing, presumably a
          view); the input object itself is not modified.
    '''
    is_mito = adata.var[VAR_MITO]
    return adata[:, ~is_mito]

def score_doublets(adata:AnnData, plot:bool=False) -> AnnData:
    '''
    Run Scrublet on ``adata.X`` and store its per-cell outputs in ``adata.obs``.

    The two values returned by ``Scrublet.scrub_doublets()`` are written to
    ``adata.obs[OBS_DOUBLET_SCORES]`` and ``adata.obs[OBS_PREDICTED_DOUBLETS]``.

    Args:
        plot: if True, also draw Scrublet's score histogram.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    scrubber = scr.Scrublet(adata.X)
    scores, predicted = scrubber.scrub_doublets()
    adata.obs[OBS_DOUBLET_SCORES] = scores
    adata.obs[OBS_PREDICTED_DOUBLETS] = predicted
    if plot:
        scrubber.plot_histogram()
    return adata


In [None]:
#| export
def add_gene_annotations(
    adata:AnnData,
    annotation_file:str
) -> AnnData:
    '''
    Join per-gene annotations from a CSV onto ``adata.var`` by Ensembl ID.

    NOTE:
        - ``adata.var_names`` must be Ensembl IDs; they are matched against the
          CSV's ``Ensembl`` column.
        - genes without an annotation row are dropped (inner join). In that
          case a new, gene-subset AnnData is returned instead of modifying
          ``adata`` in place, so always use the return value.
        - if an Ensembl ID appears more than once in the CSV, its first row
          is used.
        - afterwards genes are re-indexed by ``adata.var[VAR_GENE_SYMBOL]``
          and made unique.

    Args:
        adata: AnnData whose var names are Ensembl IDs.
        annotation_file: CSV with a header row and an ``Ensembl`` column;
            every value is read as a string.

    Returns:
        The annotated AnnData.
    '''
    # Load gene annotation information (extracted from bioconductor)
    gene_annotation = pd.read_csv(
        annotation_file, index_col=None, header=0
    ).astype(str)

    assert 'Ensembl' in gene_annotation.columns, \
        f"{annotation_file} has no 'Ensembl' column"
    gene_annotation.index = list(gene_annotation.Ensembl)
    # A repeated Ensembl ID would make the join ambiguous; keep the first row.
    gene_annotation = gene_annotation[~gene_annotation.index.duplicated()]

    # Inner join: drop unannotated genes from the AnnData itself *before*
    # replacing ``var`` -- AnnData rejects a ``var`` frame whose length
    # differs from ``n_vars``.
    has_annotation = adata.var.index.isin(gene_annotation.index)
    if not has_annotation.all():
        adata = adata[:, has_annotation].copy()

    # Rows in exactly the AnnData's gene order.
    annotation_rows = gene_annotation.loc[adata.var.index]
    annotation_rows.index = adata.var.index
    adata.var = pd.concat([adata.var, annotation_rows], axis=1)

    adata.var.index = list(adata.var[VAR_GENE_SYMBOL])
    # Enforce uniqueness
    adata.var_names_make_unique()
    return adata

In [None]:
#| export
from spot_mark_gene.utils import (
    time_to_num_from_idx_to_time
)
def combine_timepoints(
    *adatas:AnnDatas, 
    idx_to_time:dict,
    print_counts:bool=False
):
    '''
    Concatenate per-timepoint AnnData objects and label each cell's batch.

    ``ad.concat`` is called with ``index_unique="_"``, which appends
    ``"_<key>"`` to every cell name; without explicit ``keys`` the key is the
    0-based position of the source object in ``adatas`` ("0", "1", ...).
    That suffix is mapped through ``idx_to_time`` into ``obs['batch']`` and
    then through the derived ``time_to_num`` mapping into ``obs['timepoint']``.

    Args:
        *adatas: AnnData objects, one per timepoint, in index order.
        idx_to_time: maps the positional suffix (as a string) to a time label.
        print_counts: if True, print the number of cells per batch.

    Examples:

        idx_to_time = {
            '0': '12hr', 
            '1': '18hr', 
            '2': '24hr'
        }

        time_to_num = {
            '12hr': '12', 
            '18hr': '18', 
            '24hr': '24'
        }
    '''
    time_to_num = time_to_num_from_idx_to_time(idx_to_time)

    adata = ad.concat(
        list(adatas),
        index_unique="_", merge="same", join='outer'
    )

    # Use everything after the last "_" (the suffix ad.concat appended) rather
    # than just the final character, so positions >= 10 ("_10") stay distinct.
    adata.obs['batch'] = [
        name.rsplit('_', 1)[-1] for name in adata.obs.index.astype(str)
    ]
    adata.obs['batch'] = adata.obs['batch'].replace(idx_to_time)

    adata.obs['timepoint'] = adata.obs['batch'].replace(time_to_num)

    if print_counts:
        print(adata.obs.batch.value_counts())
    return adata

In [None]:
#| export 
def calc_qc_stats(adata:AnnData) -> AnnData:
    '''
    Flag mitochondrial / ribosomal genes and compute scanpy QC metrics.

    Gene flags (prefix match on ``adata.var_names``, case-insensitive so both
    mouse ``mt-``/``Rps``/``Rpl`` and human ``MT-``/``RPS``/``RPL`` symbols
    are caught):
        - ``adata.var['mito']``: name starts with ``mt-``
        - ``adata.var['ribo']``: name starts with ``rps`` or ``rpl``

    ``sc.pp.calculate_qc_metrics`` is then run in place with those two
    ``qc_vars``, and ``adata.obs['log10_total_counts']`` is added.

    NOTE:
        - cells with ``total_counts == 0`` get ``-inf`` in the log column.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Lower-case once; gene symbol capitalisation differs between species.
    gene_names = adata.var_names.str.lower()
    adata.var["mito"] = gene_names.str.startswith("mt-")
    adata.var['ribo'] = gene_names.str.startswith(("rps", "rpl"))

    sc.pp.calculate_qc_metrics(
        adata, qc_vars=["mito", "ribo"], inplace=True
    )

    adata.obs['log10_total_counts'] = np.log10(
        adata.obs["total_counts"]
    )
    return adata

In [None]:
#| export 
def filter_by_cutoffs(
    adata:AnnData, 
    lower:float=None, 
    upper:float=None,
    obs_key:CUTOFF_KIND='total_counts',
    print_counts:bool=False,      
) -> AnnData:
    '''
    Keep only cells whose ``adata.obs[obs_key]`` lies strictly within cutoffs.

    Args:
        adata: AnnData to filter.
        lower: exclusive lower bound; ``None`` disables it.
        upper: exclusive upper bound; ``None`` disables it.
        obs_key: ``adata.obs`` column the bounds apply to.
        print_counts: if True, print per-batch cell counts after filtering
            (requires ``adata.obs['batch']``).

    Returns:
        The filtered AnnData (a cell subset of ``adata``). The subset must be
        taken from the return value; ``adata`` itself is not modified.
    '''
    assert obs_key is not None
    # Rebind after each subset so the filters actually take effect and
    # compose (the upper bound is applied to the already-lower-filtered cells).
    if lower is not None:
        adata = adata[adata.obs[obs_key] > lower]
    if upper is not None:
        adata = adata[adata.obs[obs_key] < upper]
    if print_counts:
        print(adata.obs.batch.value_counts())
    return adata

def apply_filter_by_cutoffs(
    adata:AnnData, 
    cutoff_specs:CutoffSpecifications,
    print_counts:bool=False   
):
    '''
    Apply several cutoff specifications one after another.

    Each spec's ``lower``, ``upper`` and ``obs_key`` are passed to
    ``filter_by_cutoffs``; the output of one spec is the input to the next.

    Returns:
        The AnnData remaining after all specs have been applied.
    '''
    filtered = adata
    for cutoff in cutoff_specs:
        filtered = filter_by_cutoffs(
            filtered,
            lower=cutoff.lower,
            upper=cutoff.upper,
            obs_key=cutoff.obs_key,
            print_counts=print_counts,
        )
    return filtered

In [None]:
#| export
def add_prenormalization_layer(adata:AnnData) -> AnnData:
    '''
    Snapshot the current ``adata.X`` into ``adata.layers[LAYER_PRENORM]``.

    A copy is stored rather than a reference to ``adata.X``, so later in-place
    edits of ``adata.X`` (e.g. in-place normalisation) cannot silently change
    the saved unnormalised counts.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Store unnormalised counts (copied, not aliased to X)
    adata.layers[LAYER_PRENORM] = adata.X.copy()
    return adata

def add_gene_detection_layer(adata:AnnData) -> AnnData:
    '''
    Add a 0/1 "gene detected" layer derived from the unnormalised counts.

    ``adata.layers[LAYER_DETECTED]`` becomes a CSR matrix (int64) that is 1
    where ``adata.layers[LAYER_PRENORM] > 0`` and 0 elsewhere. If the
    prenormalisation layer is missing, it is first created from ``adata.X``
    via ``add_prenormalization_layer``.

    Sparse counts are compared directly (no dense copy of the full matrix);
    dense counts are accepted as well.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Store unnormalised counts
    if LAYER_PRENORM not in adata.layers:
        adata = add_prenormalization_layer(adata)

    counts = adata.layers[LAYER_PRENORM]
    if scipy.sparse.issparse(counts):
        detected = (counts > 0).astype(np.int64)
    else:
        detected = (np.asarray(counts) > 0).astype(np.int64)

    # Add layer of gene detection
    adata.layers[LAYER_DETECTED] = scipy.sparse.csr_matrix(detected)
    return adata

def sqrt_library_size_normalize(
    adata:AnnData,
    rescale:float|str|None=10000,
) -> AnnData:
    '''
    Library-size normalise each cell, then square-root transform.

    Each row of ``X`` is divided by its total (rows summing to zero are left
    as zeros) and multiplied by ``rescale``; the square root of the result is
    stored back as a CSR matrix.

    NOTE(review): the previous version called ``sc.normalize`` /
    ``sc.transform``, which scanpy (imported here as ``sc``) does not
    provide; those names match scprep's API. This numpy version follows
    scprep's documented ``library_size_normalize`` behaviour -- confirm the
    default ``rescale`` matches what was intended.

    Args:
        adata: AnnData with non-negative (dense or sparse) ``X``; not modified.
        rescale: target total per cell. ``'median'`` / ``'mean'`` use the
            median / mean library size, a number is used as-is, and ``None``
            leaves rows summing to 1.

    Returns:
        A normalised copy of ``adata``.

    Raises:
        ValueError: if the normalised data contain negative values.
    '''
    adata = adata.copy()
    X = adata.X
    dense = X.toarray() if scipy.sparse.issparse(X) else np.asarray(X)
    dense = dense.astype(np.float64)

    libsize = dense.sum(axis=1)
    if isinstance(rescale, str):
        if rescale == 'median':
            scale = np.median(libsize)
        elif rescale == 'mean':
            scale = np.mean(libsize)
        else:
            raise ValueError(
                f"rescale must be 'median', 'mean', a number or None; got {rescale!r}"
            )
    elif rescale is None:
        scale = 1.0
    else:
        scale = float(rescale)

    # Avoid 0/0 for empty cells: they stay all-zero.
    safe_libsize = np.where(libsize == 0, 1.0, libsize)
    normalized = dense / safe_libsize[:, None] * scale

    if np.any(normalized < 0):
        raise ValueError('Cannot square root transform negative values')
    adata.X = scipy.sparse.csr_matrix(np.sqrt(normalized))
    return adata

In [None]:
#| export
def add_batch_mean_center_layer(adata:AnnData) -> AnnData:
    '''
    Store a per-batch mean-centred copy of ``adata.X`` as a layer.

    For every value of ``adata.obs['batch']`` the column means over that
    batch's cells are subtracted from those cells. The result is stored in
    ``adata.layers[LAYER_SCALED_NORMALIZED]`` as a CSR matrix (used for cell
    cycle scoring). ``adata.X`` keeps the uncentred values, since downstream
    PCA does its own centring. ``adata.raw`` is set to the uncentred data.

    NOTE(review): the previous version called ``sc.normalize.batch_mean_center``,
    which scanpy (imported here as ``sc``) does not provide -- it matches the
    scprep API. This is an explicit numpy implementation of the same idea.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Batch mean center before cell cycle scoring
    adata.raw = adata

    X = adata.X
    centered = (X.toarray() if scipy.sparse.issparse(X) else np.asarray(X))
    centered = centered.astype(np.float64)  # always a fresh array; X untouched

    batches = np.asarray(adata.obs['batch'])
    for batch in pd.unique(batches):
        rows = batches == batch
        centered[rows] -= centered[rows].mean(axis=0)

    adata.layers[LAYER_SCALED_NORMALIZED] = scipy.sparse.csr_matrix(centered)
    return adata

def score_genes_cell_cycle_with_batch_mean_center_data(
        adata:AnnData,
        s_genes:Sequence[str], 
        g2m_genes:Sequence[str],
) -> AnnData:
    '''
    Score cell cycle phase on the batch-mean-centred layer.

    ``sc.tl.score_genes_cell_cycle`` needs an AnnData, so the layer
    ``adata.layers[LAYER_SCALED_NORMALIZED]`` is wrapped in a temporary
    AnnData (sharing ``adata``'s obs/var, without ``.raw`` so the centred
    values are scored). The resulting ``S_score``, ``G2M_score`` and
    ``phase`` columns are copied into ``adata.obs``.

    Args:
        s_genes, g2m_genes: gene names, matched against ``adata.var_names``.

    Returns:
        The same ``adata`` object with the three obs columns added.
    '''
    centered = adata.layers[LAYER_SCALED_NORMALIZED]
    scoring = ad.AnnData(
        X=centered, obs=adata.obs.copy(), var=adata.var.copy()
    )
    sc.tl.score_genes_cell_cycle(scoring, s_genes=s_genes, g2m_genes=g2m_genes)

    # Same obs index on both objects, so this copies row-for-row.
    for key in ('S_score', 'G2M_score', 'phase'):
        adata.obs[key] = scoring.obs[key]
    return adata

def load_human_genes(
    adata:AnnData, filename:str
) -> List[str]:
    '''
    Return the var names whose human gene symbol is listed in ``filename``.

    ``filename`` is read as one gene symbol per line (surrounding whitespace
    stripped); genes are selected where ``adata.var.HumanGeneSymbol`` is in
    that list.

    NOTE:
        - uses adata to confirm validity (``adata.var`` must have a
          ``HumanGeneSymbol`` column).
        - the value returned is the matching slice of ``adata.var.index``.
    '''
    assert hasattr(adata.var, 'HumanGeneSymbol')

    with open(filename, 'r') as handle:
        wanted = [line.strip() for line in handle.readlines()]

    in_file = adata.var.HumanGeneSymbol.isin(wanted)
    return adata.var.index[in_file]

In [None]:
#| export 
def select_hvg_per_batch(
    adata:AnnData,
    hvg_kwargs:dict=dict(cutoff=None, percentile=90)
) -> AnnData:
    '''
    Flag genes that are highly variable in at least one batch.

    For every unique value of ``adata.obs.batch``, that batch's cells are
    passed to ``sc.select.highly_variable_genes`` (with ``hvg_kwargs``); the
    gene names it returns are flagged in
    ``adata.var[f'highly_variable_{batch}']``. The union over all batches is
    stored in ``adata.var['highly_variable']``.

    NOTE(review): ``sc`` is imported as ``scanpy`` in this module, and scanpy
    does not provide ``sc.select``; this call matches the scprep
    ``select.highly_variable_genes`` API. Confirm the intended library --
    as written this call is expected to fail.

    Args:
        adata: needs ``adata.obs['batch']``; ``.X`` is densified with
            ``toarray()``, so it must be sparse -- TODO confirm dense X never
            reaches here.
        hvg_kwargs: forwarded unchanged to the HVG selector (not mutated).

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Select highly variable genes from any batch
    hvg_all = []
    for batch in adata.obs.batch.unique():
        # Subsetting cells leaves var untouched, so the second argument is the
        # full gene index; it is unpacked below as the selected gene names.
        normalised, hgv_vars = sc.select.highly_variable_genes(
            adata[adata.obs.batch == batch].X.toarray(), 
            adata[adata.obs.batch == batch].var.index, 
            **hvg_kwargs
        )
        hvg_all.extend(hgv_vars)
        adata.var[f'highly_variable_{batch}'] = adata.var.index.isin(hgv_vars)
        # Only the selected gene names are kept; free the subset matrix.
        del normalised
        # Progress: number of distinct HVGs accumulated over batches so far.
        print(f"Unique HVGs after {batch} {len(np.unique(np.array(hvg_all)))}")
        
    adata.var['highly_variable'] = adata.var.index.isin(hvg_all)
    return adata

In [None]:
#| export 
def add_tf_annotations_from_csv(
    adata:AnnData, filename:str,
    tf_key:str, ensemble_key:str,
    print_counts:bool=False
) -> AnnData:
    '''
    Flag genes whose ID appears in a transcription-factor CSV.

    The CSV (header row, all values read as strings) must have an
    ``ensemble_key`` column; ``adata.var[tf_key]`` becomes True for genes
    whose ``adata.var[ensemble_key]`` value occurs in it.

    Args:
        print_counts: if True, print the True/False counts of the new column.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    assert hasattr(adata.var, ensemble_key)
    tf_table = pd.read_csv(filename, index_col=None, header=0).astype(str)

    known_ids = tf_table[ensemble_key]
    adata.var[tf_key] = adata.var[ensemble_key].isin(known_ids)

    if print_counts:
        print(adata.var[tf_key].value_counts())
    return adata

def add_human_tfs_from_csv(
    adata:AnnData, filename:str,
    print_counts:bool=False
) -> AnnData:
    '''
    Flag human transcription factors listed in ``filename``.

    Thin wrapper around ``add_tf_annotations_from_csv`` that writes
    ``adata.var[VAR_HUMAN_TF]``, matching on ``VAR_HUMAN_ENSEMBLE_ID``.
    '''
    return add_tf_annotations_from_csv(
        adata,
        filename,
        tf_key=VAR_HUMAN_TF,
        ensemble_key=VAR_HUMAN_ENSEMBLE_ID,
        print_counts=print_counts,
    )

def add_mouse_tfs_from_csv(
    adata:AnnData, filename:str,
    print_counts:bool=False
) -> AnnData:
    '''
    Flag mouse transcription factors listed in ``filename``.

    Thin wrapper around ``add_tf_annotations_from_csv`` that writes
    ``adata.var[VAR_MOUSE_TF]``, matching on ``VAR_MOUSE_ENSEMBLE_ID``.
    '''
    return add_tf_annotations_from_csv(
        adata,
        filename,
        tf_key=VAR_MOUSE_TF,
        ensemble_key=VAR_MOUSE_ENSEMBLE_ID,
        print_counts=print_counts,
    )

In [None]:
#| export
from scipy.stats import zscore
def zscore_markers_in_layer(
    adata:AnnData,
    markers:List[str],
    obs_key:str='Markers_zscore',
    layer_key:str=EMB_MAGIC,    
) -> AnnData:    
    '''
    Score each cell by the sum of per-marker z-scores from a layer.

    The ``markers`` columns of ``adata.layers[layer_key]`` are z-scored
    across cells (``scipy.stats.zscore``, per column) and summed per cell into
    ``adata.obs[obs_key]``. Markers not present in ``adata.var_names`` are
    ignored.

    NOTE:
        - a marker that is constant across cells z-scores to NaN, which the
          row sum skips (treated as 0).

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Score cells based on select marker expression (sum of zscores of smoothed counts)
    col_subset = adata.var.index.isin(markers)
    values = adata.layers[layer_key][:, col_subset]
    values = values.toarray() if scipy.sparse.issparse(values) else np.asarray(values)
    df_markers = pd.DataFrame(
        values,
        columns = adata.var.index[col_subset],
        index = adata.obs.index
    )
    # apply() returns a new frame; keep it so the z-scores are what get summed.
    df_markers = df_markers.apply(zscore)
    adata.obs[obs_key] = df_markers.sum(axis=1)
    return adata

def subset_markers(
    adata:AnnData,
    obs_key:str='Markers_cell',
    score_key:str='Markers_zscore',
    lower:float=2.2,
    upper:float=None,
    marker_name:str='marker',
    other_name:str='other'
) -> AnnData:
    '''
    Label each cell as marker-positive or not from its marker score.

    A cell is labelled ``marker_name`` when
    ``lower < adata.obs[score_key] < upper`` and ``other_name`` otherwise;
    either bound may be ``None`` to disable it. Cells with a NaN score are
    labelled ``other_name`` whenever a bound is active. Labels are written to
    ``adata.obs[obs_key]``.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Start from "every cell passes", indexed like obs so the masks line up
    # cell-by-cell (iterating a DataFrame would yield column names instead).
    in_range = pd.Series(True, index=adata.obs.index)
    if upper is not None:
        in_range &= adata.obs[score_key] < upper
    if lower is not None:
        in_range &= lower < adata.obs[score_key]

    adata.obs[obs_key] = np.where(in_range, marker_name, other_name)
    return adata


In [None]:
#| export
def run_pca(
    adata:AnnData,
    pca_kwargs:dict=dict(n_components=100),
    plot_scree:bool=False,
    emb_key:str=EMB_PCA,
    col_subset:SeriesLike=None
) -> AnnData:
    '''
    Compute a PCA embedding of ``adata.X`` and store it in ``adata.obsm``.

    Args:
        pca_kwargs: options for the PCA call. ``return_singular_values=True``
            and ``seed=3`` are always added (overriding any given values);
            the dict passed in is not modified.
        plot_scree: if True, plot the singular values.
        emb_key: ``adata.obsm`` key for the embedding.
        col_subset: optional gene mask/selection; only those columns are used.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    # Compute PCs for initial cell graph.
    # Build a new dict: mutating ``pca_kwargs`` would alter the caller's dict
    # and the shared default between calls.
    pca_params = {**pca_kwargs, 'return_singular_values': True, 'seed': 3}

    source = adata[:, col_subset] if col_subset is not None else adata
    x = source.X
    x = x.toarray() if scipy.sparse.issparse(x) else np.asarray(x)

    # NOTE(review): ``sc`` is scanpy in this module, which provides neither
    # ``sc.reduce`` nor ``sc.plot``; these calls match the scprep API.
    # Confirm the intended library.
    pcs, svs = sc.reduce.pca(x, **pca_params)
    adata.obsm[emb_key] = pcs
    if plot_scree:
        sc.plot.scree_plot(svs, cumulative=False)
    return adata

def run_pca_on_hvg(
    adata:AnnData,
    pca_kwargs:dict=dict(n_components=100),
    plot_scree:bool=False,
) -> AnnData:
    '''
    Run ``run_pca`` restricted to highly variable genes.

    Uses ``adata.var.highly_variable`` as the gene subset and stores the
    embedding under ``EMB_PCA_HVG``.
    '''
    hvg_mask = adata.var.highly_variable
    return run_pca(
        adata,
        pca_kwargs=pca_kwargs,
        plot_scree=plot_scree,
        emb_key=EMB_PCA_HVG,
        col_subset=hvg_mask,
    )

In [None]:
#| export 
def run_phate_using_g(
    adata:AnnData,    
    g: Graph = None,
    phate_kwargs:dict = dict(t=70),
    g_kwargs:dict = dict(knn=10),    
    emb_key:str=EMB_PHATE, 
) -> Tuple[AnnData, Graph]:  
    '''
    Embed cells with PHATE, building the cell graph from PCA if needed.

    If ``g`` is None, a ``graphtools.Graph`` is built from
    ``adata.obsm[emb_key.replace('phate', 'pca')]`` using ``g_kwargs`` plus
    ``random_state=3`` and ``n_pca=None``. PHATE is then fit on the graph with
    ``phate_kwargs`` plus ``random_state=3`` and the embedding is stored in
    ``adata.obsm[emb_key]``. Neither kwargs dict passed in is modified.

    The default ``t=70`` fixes PHATE's diffusion time (an earlier note records
    that automatic selection gave t=46).

    Returns:
        ``(adata, g)`` -- ``adata`` modified in place and the graph used, so
        it can be reused (e.g. by ``run_magic``).

    Raises:
        ValueError: if ``g`` is None and the PCA key is missing from ``adata.obsm``.
    '''
    # Fresh dicts: never mutate the caller's kwargs or the shared defaults.
    # ``n_pca`` is supplied only here -- passing it again explicitly to
    # gt.Graph raised "got multiple values for keyword argument 'n_pca'".
    graph_params = {**g_kwargs, 'random_state': 3, 'n_pca': None}
    phate_params = {**phate_kwargs, 'random_state': 3}

    if g is None:
        pca_key = emb_key.replace('phate', 'pca')
        print((
            f'g is None. Will attempt to calculate with'
            f" PCA stored in adata.obsm['{pca_key}']."
        ))

        if pca_key not in adata.obsm:
            raise ValueError(f'{pca_key} not in adata.obsm')

        g = gt.Graph(adata.obsm[pca_key], **graph_params)

    phate_op = phate.PHATE(**phate_params)
    adata.obsm[emb_key] = phate_op.fit_transform(g)
    return adata, g

def run_phate_on_hvg(
    adata:AnnData,    
    g: Graph = None,
    phate_kwargs:dict = dict(t=70),
    g_kwargs:dict = dict(knn=10),    
    emb_key:str=EMB_PHATE_HVG,    
) -> Tuple[AnnData, Graph]:
    '''
    Run ``run_phate_using_g`` with the HVG embedding key by default.

    Returns:
        ``(adata, g)`` exactly as returned by ``run_phate_using_g``.
    '''
    return run_phate_using_g(
        adata,
        g=g,
        phate_kwargs=phate_kwargs,
        g_kwargs=g_kwargs,
        emb_key=emb_key,
    )

In [None]:
#| export
def run_magic(
    adata:AnnData, g:Graph,
    knn_max:int = 60,
    layer_key:str = 'X_magic',
) -> AnnData:
    '''
    Run MAGIC on ``adata`` using a copy of a precomputed graph.

    ``g`` is deep-copied (the caller's graph is left untouched), its
    ``knn_max`` is set, and its ``data`` / ``data_nu`` are pointed at the full
    expression DataFrame before fitting ``magic.MAGIC`` with that graph. All
    genes are transformed and stored as a CSR matrix in
    ``adata.layers[layer_key]``.

    Args:
        knn_max: value assigned to the copied graph's ``knn_max``.
        layer_key: destination layer; the default keeps the historical
            ``'X_magic'``. ``zscore_markers_in_layer`` reads ``EMB_MAGIC`` by
            default -- pass ``layer_key=EMB_MAGIC`` if the two differ.

    Returns:
        The same ``adata`` object, modified in place.
    '''
    G = copy.deepcopy(g)
    G.knn_max = knn_max
    G.data = adata.to_df()
    G.data_nu = adata.to_df()
    magic_op = magic.MAGIC().fit(adata.to_df(), graph=G)
    data_magic = magic_op.transform(genes='all_genes')
    adata.layers[layer_key] = scipy.sparse.csr_matrix(data_magic)
    return adata

In [None]:
#| hide
# Regenerate the library's .py modules from the notebooks' `#| export` cells.
import nbdev; nbdev.nbdev_export()