## scMomentum

In [1]:
import anndata
from collections import defaultdict
import copy
import csv
from joblib import Parallel, delayed
from matplotlib import colors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from matplotlib import style
from mpl_toolkits.mplot3d import Axes3D
import multiprocessing
import networkx as nx
import numba
import numpy as np
import numpy.random as rnd
import os
import pandas as pd
import pickle
import random
from random import choices
import re
import scipy as scp
import scipy.integrate as integrate
from scipy.special import hyp2f1 as hyper
import scipy.stats as stats
from scipy.stats import norm as normal
import scvelo as scv
from scvelo.tools.velocity_embedding import quiver_autoscale,velocity_embedding
import seaborn as sns
from sklearn import preprocessing
from sklearn.cluster import Birch
import sklearn.decomposition as skd
from sklearn.neighbors import NearestNeighbors
import string
import umap
scv.settings.verbosity = 0  # Minimise scvelo log output (0 is its lowest verbosity level).

### Utilities 



In [2]:
def load_adata(file):
    """Unpickle and return the object stored at path ``file``.

    Despite the name, any picklable object is returned unchanged (not only
    AnnData). Only load files you trust: unpickling can execute arbitrary code.
    """
    with open(file, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded
    
def save_adata(obj, filename):
    """Pickle ``obj`` to the path ``filename`` (overwriting it) and return None.

    The highest available pickle protocol is used.
    """
    with open(filename, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)
    
def unique(list1):
    """Return the sorted distinct elements of ``list1`` as a numpy array."""
    as_array = np.array(list1)
    distinct = np.unique(as_array)
    return distinct

def pad_matrix(n, m):
    """Build the neighbour-averaging operator of an ``n`` x ``m`` grid.

    Grid cell ``(r, c)`` is flattened row-major to index ``i = r * m + c``.
    Row ``i`` of the returned ``(n*m, n*m)`` CSR matrix puts equal weight on
    each of the (up to 8) grid neighbours of cell ``i``, excluding ``i``
    itself, so every row sums to 1 (corners 1/3, edges 1/5, interior 1/8
    when ``n, m >= 2``). A cell with no neighbours (1x1 grid) maps to itself.

    Parameters
    ----------
    n : int
        Number of grid rows.
    m : int
        Number of grid columns.

    Returns
    -------
    scipy.sparse.csr_matrix
        The averaging operator, dtype float.
    """
    size = n * m
    rows, cols, vals = [], [], []
    for r in range(n):
        for c in range(m):
            i = r * m + c
            # Bounds checks (rather than special-cased indices) keep this
            # correct for single-row / single-column grids as well.
            neighbours = [
                (r + dr) * m + (c + dc)
                for dr in (-1, 0, 1)
                for dc in (-1, 0, 1)
                if (dr or dc) and 0 <= r + dr < n and 0 <= c + dc < m
            ]
            if not neighbours:
                neighbours = [i]
            weight = 1.0 / len(neighbours)
            rows.extend([i] * len(neighbours))
            cols.extend(neighbours)
            vals.extend([weight] * len(neighbours))
    # Assemble once from coordinate lists: element-wise assignment into a
    # CSR matrix is very slow.
    return scp.sparse.csr_matrix(
        (np.asarray(vals, dtype=float),
         (np.asarray(rows, dtype=np.intp), np.asarray(cols, dtype=np.intp))),
        shape=(size, size),
        dtype=float,
    )

def soften(Z, it=2, ep=None, verbose=True):
    """Smooth a 2-D array by repeatedly replacing each cell with its neighbours' mean.

    Each step applies ``pad_matrix(*Z.shape)`` to the row-major flattened
    array and reshapes back to ``Z.shape``.

    Parameters
    ----------
    Z : 2-D array-like with ``.shape``
        Values on an ``n`` x ``m`` grid.
    it : int, default 2
        Number of smoothing steps when ``ep`` is not given (or is falsy).
    ep : float, optional
        If truthy, iterate until the largest absolute change between two
        consecutive steps is ``<= ep`` instead of using ``it``.
    verbose : bool, default True
        In ``ep`` mode, print the change after each step.

    Returns
    -------
    numpy.ndarray
        The smoothed array, same shape as ``Z``.
    """
    shape = Z.shape
    # The operator depends only on the grid shape, so build it once.
    smoother = pad_matrix(shape[0], shape[1])
    if ep:
        err = ep + 1
        while err > ep:
            previous = Z
            Z = smoother.dot(np.ravel(Z)).reshape(shape)
            # Absolute change: a signed max can be <= ep while large
            # negative changes are still happening.
            err = np.abs(Z - previous).max()
            if verbose:
                print(err)
        return Z
    for _ in range(it):
        Z = smoother.dot(np.ravel(Z)).reshape(shape)
    return Z

def sigmoide(cells, th=None):
    """Binarise ``cells`` to +1 / -1 around a threshold.

    Entries strictly greater than the threshold become 1, all others
    (including ties and NaN) become -1. When ``th`` is None the threshold is
    the mean along axis 0 (a per-column mean for 2-D input).
    """
    threshold = np.mean(cells, 0) if th is None else th
    above = (cells - threshold) > 0
    return 2 * above - 1

def _embedding_velocity(adata, cells_2d):
    """Return one embedded velocity vector per row of ``cells_2d``.

    Uses ``scv.tools.transition_matrix(adata)`` with self-transitions removed.
    Each cell's vector is the probability-weighted sum of unit displacements
    towards its transition neighbours, minus the uniform-weight baseline
    ``probs.mean() * sum(displacements)``. Assumes the rows of ``cells_2d``
    are aligned with ``adata.obs`` (TODO confirm at call sites).
    """
    T = scv.tools.transition_matrix(adata)
    T.setdiag(0)
    T.eliminate_zeros()
    V_emb = np.zeros(cells_2d.shape)
    # Dense row access is much faster, but only affordable for smaller data.
    densify = adata.n_obs < 1e4
    TA = T.toarray() if densify else None  # `.A` was removed from SciPy sparse matrices
    for i in range(cells_2d.shape[0]):
        indices = T[i].indices
        dX = cells_2d[indices] - cells_2d[i, None]  # shape (n_neighbors, n_dim)
        dX /= scv.utils.norm(dX)[:, None]
        dX[np.isnan(dX)] = 0  # zero displacement (steady state) yields 0/0
        probs = TA[i, indices] if densify else T[i].data
        V_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0)
    return V_emb


def project_velocity_on_grid(
    adata,
    cells_2d,
    density=None,
    smooth=None,
    n_neighbors=None,
    min_mass=None,
    autoscale=True,
    adjust_for_stream=False,
    cutoff_perc=None,
):
    """Project RNA velocity onto an embedding and average it on a regular grid.

    Parameters
    ----------
    adata : AnnData
        Object accepted by ``scv.tools.transition_matrix``.
    cells_2d : array-like, shape (n_cells, n_dim)
        Embedding coordinates of the cells (converted to float).
    density : float, optional
        Grid resolution multiplier; ``int(50 * density)`` points per axis (default 1).
    smooth : float, optional
        Gaussian kernel width as a fraction of the grid spacing (default 0.5).
    n_neighbors : int, optional
        Cells used per grid point; default ``n_obs / 50`` (at least 1),
        and never more than the number of valid cells.
    min_mass : float, optional
        Mass threshold (default 1). In stream mode it is an exponent
        (``10 ** (min_mass - 6)``); otherwise a percentage of the
        99th-percentile neighbourhood mass.
    autoscale : bool, default True
        Rescale grid velocities with ``quiver_autoscale`` (non-stream mode only).
    adjust_for_stream : bool, default False
        Return grid axes and a ``(2, ns, ns)`` velocity array, with
        low-mass / short-velocity points set to NaN in the first component
        (assumes a 2-D embedding).
    cutoff_perc : float, optional
        Stream mode only: percentile of neighbourhood velocity length below
        which points are masked (default 5).

    Returns
    -------
    (X_grid, V_grid)
        Grid coordinates and grid velocities.
    """
    # Float copy: the in-place normalisation would fail on integer coordinates.
    X_emb = np.asarray(cells_2d, dtype=float)
    V_emb = _embedding_velocity(adata, X_emb)

    # remove invalid cells
    idx_valid = np.isfinite(X_emb.sum(1) + V_emb.sum(1))
    X_emb = X_emb[idx_valid]
    V_emb = V_emb[idx_valid]

    # prepare grid
    n_obs, n_dim = X_emb.shape
    density = 1 if density is None else density
    smooth = 0.5 if smooth is None else smooth
    grs = []
    for dim_i in range(n_dim):
        m, M = np.min(X_emb[:, dim_i]), np.max(X_emb[:, dim_i])
        m = m - 0.01 * np.abs(M - m)
        M = M + 0.01 * np.abs(M - m)
        grs.append(np.linspace(m, M, int(50 * density)))
    meshes_tuple = np.meshgrid(*grs)
    X_grid = np.vstack([i.flat for i in meshes_tuple]).T

    # estimate grid velocities
    if n_neighbors is None:
        # int(n_obs / 50) is 0 for fewer than 50 cells, which NearestNeighbors rejects.
        n_neighbors = max(1, int(n_obs / 50))
    n_neighbors = min(n_neighbors, n_obs)
    nn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1)
    nn.fit(X_emb)
    dists, neighs = nn.kneighbors(X_grid)
    scale = np.mean([(g[1] - g[0]) for g in grs]) * smooth
    weight = normal.pdf(x=dists, scale=scale)
    p_mass = weight.sum(1)
    V_grid = (V_emb[neighs] * weight[:, :, None]).sum(1)
    V_grid /= np.maximum(1, p_mass)[:, None]
    if min_mass is None:
        min_mass = 1

    if adjust_for_stream:
        X_grid = np.stack([np.unique(X_grid[:, 0]), np.unique(X_grid[:, 1])])
        ns = int(np.sqrt(len(V_grid[:, 0])))
        V_grid = V_grid.T.reshape(2, ns, ns)
        mass = np.sqrt((V_grid ** 2).sum(0))
        min_mass = 10 ** (min_mass - 6)  # default min_mass = 1e-5
        min_mass = np.clip(min_mass, None, np.max(mass) * 0.9)
        cutoff = mass.reshape(V_grid[0].shape) < min_mass
        if cutoff_perc is None:
            cutoff_perc = 5
        length = np.sum(np.mean(np.abs(V_emb[neighs]), axis=1), axis=1).T
        length = length.reshape(ns, ns)
        cutoff |= length < np.percentile(length, cutoff_perc)
        V_grid[0][cutoff] = np.nan
    else:
        min_mass *= np.percentile(p_mass, 99) / 100
        X_grid, V_grid = X_grid[p_mass > min_mass], V_grid[p_mass > min_mass]
        if autoscale:
            V_grid /= 3 * quiver_autoscale(X_grid, V_grid)
    return X_grid, V_grid

### Methods

In [None]:
def select_genes(v, xm, ng, mode, cluster=None, v_rank=None):
    """Select the genes used for network inference in one cluster.

    Parameters
    ----------
    v : pandas.DataFrame
        Cell-by-gene velocity matrix for the cells of one cluster.
    xm : pandas.DataFrame
        Cell-by-gene expression matrix for the same cells.
    ng : int-like
        Number of genes to return (converted with ``int``).
    mode : str
        Selection method (rows containing NaN are dropped first):
          * 'abstop'   - highest mean absolute velocity across cells
          * 'ran'      - random subset of genes (all genes if ng >= #genes)
          * 'top'      - highest mean signed velocity across cells
          * 'minstd'   - lowest standard deviation of velocity across cells
          * 'maxstd'   - highest standard deviation of velocity across cells
          * 'vrank'    - first ``ng`` entries of ``v_rank[cluster]``
          * 'leastvar' - lowest standard deviation of expression across cells
          * 'topvar'   - highest standard deviation of expression across cells
          * 'highexp'  - highest mean expression across cells
    cluster : hashable, optional
        Cluster key for 'vrank'; falls back to the module-level global ``c``.
    v_rank : mapping, optional
        Cluster -> ordered gene ranking (supporting slicing and ``.tolist()``)
        for 'vrank'; falls back to the module-level global ``V_rank``.

    Returns
    -------
    list
        Selected gene names.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the options above.
    """
    ng = int(ng)

    def _ranked(stats, ascending):
        # Gene names ordered by `stats`, truncated to the requested size.
        return stats.sort_values(ascending=ascending).head(ng).index.tolist()

    if mode == 'abstop':
        return _ranked(v.dropna().abs().mean(0), ascending=False)
    if mode == 'ran':
        v_clean = v.dropna()
        if ng >= len(v_clean.columns):
            return v_clean.columns.tolist()
        return v_clean.sample(ng, axis=1).columns.tolist()
    if mode == 'top':
        return _ranked(v.dropna().mean(0), ascending=False)
    if mode == 'minstd':
        return _ranked(v.dropna().std(0), ascending=True)
    if mode == 'maxstd':
        return _ranked(v.dropna().std(0), ascending=False)
    if mode == 'vrank':
        ranking = V_rank if v_rank is None else v_rank  # noqa: F821 - notebook global
        key = c if cluster is None else cluster  # noqa: F821 - notebook global
        return ranking[key][0:ng].tolist()
    if mode == 'leastvar':
        return _ranked(xm.dropna().std(0), ascending=True)
    if mode == 'topvar':
        return _ranked(xm.dropna().std(0), ascending=False)
    if mode == 'highexp':
        return _ranked(xm.dropna().mean(0), ascending=False)
    raise ValueError(f"Unknown gene selection mode: {mode!r}")

def centroid_cells(X_arr, qu=0.65):
    """Return positions of the columns of ``X_arr`` nearest to their centroid.

    Columns are treated as points (the centroid is the mean over axis 1).
    A column is kept when its Euclidean distance to the centroid is at most
    the ``qu``-quantile of all column distances.

    Parameters
    ----------
    X_arr : 2-D numpy array
        Data with one point per column (e.g. genes x cells for one cluster).
    qu : float, default 0.65
        Quantile used as the distance cutoff.

    Returns
    -------
    list
        Column positions, in increasing order.
    """
    center = np.array([X_arr.mean(1)]).T  # column vector: one entry per row
    offsets = X_arr - center
    distances = np.sqrt(np.square(offsets).sum(axis=0))
    cutoff = np.quantile(distances, qu)
    return list(np.flatnonzero(distances <= cutoff))

def get_cluster_data(adata, cluster, ng, mode, add_tf='no'):
    """Extract the matrices needed to infer the network of one cluster.

    Cells are selected with ``adata.obs[clustcol] == cluster``, where
    ``clustcol`` is a module-level global naming the cluster column
    (defined elsewhere in the notebook). Genes whose velocity contains NaN in
    this cluster are discarded before gene selection.

    Parameters
    ----------
    adata : AnnData
        Needs layers 'velocity' and 'spliced' and ``var['fit_gamma']``.
    cluster : hashable
        Cluster label to extract.
    ng : int-like
        Number of genes passed to ``select_genes``.
    mode : str
        Gene selection method passed to ``select_genes``.
    add_tf : str, default 'no'
        If 'yes', also add the TFs returned by ``select_TFs`` (the combined
        list is de-duplicated and sorted by ``unique``).

    Returns
    -------
    tuple
        ``(X_c_f, X_d, V_c_f, G_f, genes)``: cell-by-gene expression
        DataFrame, the same data as a float64 DataFrame, cell-by-gene
        velocity DataFrame (NaN -> 0), diagonal gene-by-gene gamma matrix,
        and the list of selected gene names.
    """
    in_cluster = (adata.obs[clustcol] == cluster).values  # noqa: F821 - notebook global

    # Velocity matrix, keeping only genes with no NaN velocity in this cluster
    V_c = adata.layers['velocity'][in_cluster, :]
    V_non_nan = np.where(~np.isnan(V_c.sum(axis=0)))[0].tolist()
    genes_valid = adata.var_names[V_non_nan]
    V_c = pd.DataFrame(V_c[:, V_non_nan], columns=genes_valid)

    # Expression matrix; the spliced layer may be stored sparse or dense
    spliced = adata.layers['spliced'][:, V_non_nan][in_cluster, :]
    spliced = spliced.toarray() if scp.sparse.issparse(spliced) else np.asarray(spliced)
    X_c = pd.DataFrame(spliced, columns=genes_valid)

    # Select top n genes (optionally extended with the cluster's TFs)
    fg = select_genes(V_c, X_c, ng, mode=mode)
    if add_tf == 'yes':
        fg = unique(list(fg) + select_TFs(adata, cluster))

    # Filter matrices for model reconstruction
    genes = [g for g in fg if g in X_c.columns]
    X_c_f = X_c.loc[:, genes]  # cell by gene expression matrix
    V_c_f = V_c.loc[:, genes].fillna(0)  # cell by gene velocity matrix
    G_f = np.diag(adata.var['fit_gamma'].loc[genes].to_numpy())  # gene by gene gamma matrix

    X_d = pd.DataFrame(np.array(X_c_f, dtype=np.float64), columns=genes)

    return (X_c_f, X_d, V_c_f, G_f, genes)
    
def predict_network(Xc, Vc, Gf, fg):
    """Infer the gene-interaction weight matrix W from X W = V + X G.

    ``W`` is the least-squares (minimum-norm) solution obtained through the
    Moore-Penrose pseudoinverse of ``Xc``. NaN weights are set to 0, and
    ``np.nan_to_num`` also maps +/-inf to large finite values.

    Parameters
    ----------
    Xc : cell-by-gene expression matrix.
    Vc : cell-by-gene velocity matrix.
    Gf : gene-by-gene diagonal gamma matrix.
    fg : gene names used as both index and columns of the DataFrame.

    Returns
    -------
    tuple
        ``(W, W_d)``: the weights as a numpy array and as a labelled float64 DataFrame.
    """
    rhs = Vc + np.dot(Xc, Gf)
    weights = np.dot(np.linalg.pinv(Xc), rhs)
    weights = np.nan_to_num(weights, nan=0)  # weights that could not be inferred
    labelled = pd.DataFrame(np.array(weights, dtype=np.float64), index=fg, columns=fg)
    return weights, labelled

def select_TFs(adata, c):
    """Return the top-ranked expressed transcription factors of cluster ``c``.

    Candidates are ``adata.uns['expressed_TF']``; cells are selected with
    ``adata.obs[clustcol] == c``, where ``clustcol`` is a module-level global
    (defined elsewhere in the notebook). For each TF, the mean spliced
    expression and the absolute mean velocity are computed (NaN -> 0). TFs
    where both are zero are dropped. The remaining TFs are ranked on both
    statistics, and the ranks are combined with weights 0.3 (expression) and
    0.7 (velocity). TFs whose combined rank is at or above the 65th
    percentile are kept.

    Returns
    -------
    list
        Selected TF names, ordered by decreasing combined rank (empty if no
        TF passes the zero filter).
    """
    exp_tf = adata.uns['expressed_TF']
    ad_ctf = adata[adata.obs[clustcol] == c, exp_tf]  # noqa: F821 - notebook global

    # Layers may be sparse or dense; work on dense arrays of shape (cells, TFs).
    spliced = ad_ctf.layers['spliced']
    spliced = spliced.toarray() if scp.sparse.issparse(spliced) else np.asarray(spliced)
    velocity = ad_ctf.layers['velocity']
    velocity = velocity.toarray() if scp.sparse.issparse(velocity) else np.asarray(velocity)

    tf_dat = pd.DataFrame(
        {
            'mean_expression': np.ravel(spliced.mean(axis=0)),
            'mean_velocity': np.ravel(np.abs(velocity.mean(axis=0))),
        },
        index=exp_tf,
    )
    tf_dat = tf_dat.fillna(0)
    # Drop TFs that are neither expressed nor changing in this cluster
    silent = (tf_dat['mean_expression'] == 0) & (tf_dat['mean_velocity'] == 0)
    tf_dat = tf_dat[~silent]
    if tf_dat.empty:
        return []

    tf_ranks = tf_dat.rank()
    # Column order is (mean_expression, mean_velocity): velocity weighs more.
    tf_ranks['mean'] = np.average(tf_ranks, axis=1, weights=np.array([0.3, 0.7]))
    tf_ranks.columns = [col + '_rank' for col in tf_ranks.columns]

    tf_dat = pd.concat([tf_dat, tf_ranks], axis=1).sort_values(by='mean_rank', ascending=False)
    tf_dat['included'] = tf_dat['mean_rank'] >= np.quantile(tf_dat['mean_rank'], 0.65)

    return tf_dat.index[tf_dat['included']].tolist()

### Plotting
