adding ability to add multiple values of k #44

Open · wants to merge 7 commits into base: master
35 changes: 20 additions & 15 deletions README.md
@@ -1,7 +1,23 @@
[![Build Status](https://travis-ci.org/chriscainx/mnnpy.svg?branch=master)](https://travis-ci.org/chriscainx/mnnpy) [![Downloads](http://pepy.tech/badge/mnnpy)](http://pepy.tech/count/mnnpy)
# mnnpy - MNN-correct in python!
# M-mnnpy - MNN-correct in Python + the Marioni lab's version as a bonus!

An implementation of MNN correct in python featuring low memory usage, full multicore support and compatibility with the [scanpy](https://github.com/theislab/scanpy) framework.
An implementation of MNN correct in Python featuring low memory usage, full multicore support, and an additional implementation of MNN from the Marioni lab (modified or Marioni MNN).

Just use the `marioniCorrect()` function instead.

You can still use the regular MNN with `mnn_correct()`.
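
For instance, a minimal sketch of the new entry point (the CSV file names here are hypothetical; `marioniCorrect()` is the function added in this PR and exported from the package):

```python
import pandas as pd
import mmnnpy

# samples-by-genes expression matrices; in the Celligner pipeline these would be
# the cell line (reference) and tumor (target) data
ref = pd.read_csv("cell_lines.csv", index_col=0)   # hypothetical input file
targ = pd.read_csv("tumors.csv", index_col=0)      # hypothetical input file

corrected, mnn_pairs = mmnnpy.marioniCorrect(
    ref, targ,
    k1=20, k2=20,                  # neighbors searched in ref and in targ, respectively
    var_index=list(ref.columns),   # must match the number of columns
)
```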

## Install

Mmnnpy is available on PyPI. You can install it with `pip install mmnnpy`.

If you want the development version, do:
```
git clone https://github.com/jkobject/mnnpy.git
cd mnnpy
pip install .
```

## Below is the README from the original mnnpy tool I forked (it is still usable in this version, and even more robust!)

Batch effect correction by matching mutual nearest neighbors [(Haghverdi et al, 2018)](https://www.nature.com/articles/nbt.4091) has been implemented as a function 'mnnCorrect' in the R package [scran](https://bioconductor.org/packages/release/bioc/html/scran.html). Sadly it's extremely slow for big datasets and doesn't make full use of the parallel architecture of modern CPUs.

@@ -38,23 +54,12 @@ Finishes correcting ~50000 cells/19 batches * ~30000 genes in ~12h on a 16 core
- Compatible with scanpy
- Full verbosity

## Install

Mnnpy is available on PyPI. You can install with `pip install mnnpy`.

If you want the developing version, do:
```
git clone https://github.com/chriscainx/mnnpy.git
cd mnnpy
pip install .
```

## Usage

Mnnpy takes matrices or AnnData objects. For example:
```python
import scanpy.api as sc
import mnnpy
import mmnnpy

sample1 = sc.read("Sample1.h5ad")
sample2 = sc.read("Sample2.h5ad")
1 change: 1 addition & 0 deletions mnnpy/__init__.py
@@ -1 +1,2 @@
from .mnn import mnn_correct
from .mnn import marioniCorrect
Binary file added mnnpy/_utils.cpython-37m-x86_64-linux-gnu.so
210 changes: 205 additions & 5 deletions mnnpy/mnn.py
@@ -4,13 +4,209 @@
import numpy as np
from anndata import AnnData
from pandas import DataFrame
from .utils import transform_input_data, find_mutual_nn, compute_correction
import pandas as pd
from .utils import transform_input_data, compute_correction
from .utils import svd_internal, find_shared_subspace, get_bio_span, subtract_bio
from .utils import adjust_shift_variance
from .utils import adjust_shift_variance, l2_norm, scale_rows
from scipy.sparse import issparse
from scipy.spatial import cKDTree

###############################################################################
##################### Marioni MNN Alignment ###################################
###############################################################################

def marioniCorrect(ref_mat, targ_mat, k1=20, k2=20, fk=5, ndist=3, var_index=None, var_subset=None,
                   cosine_norm=True, n_jobs=None):
    """marioniCorrect corrects for batch effects using the Marioni lab's MNN method.

    Args:
        ref_mat (pd.DataFrame): samples-by-genes matrix of cPC-corrected data that serves as
            the reference data in the MNN alignment. In the standard Celligner pipeline this
            is the cell line data.
        targ_mat (pd.DataFrame): samples-by-genes matrix of cPC-corrected data that is corrected
            in the MNN alignment and projected onto the reference data. In the standard
            Celligner pipeline this is the tumor data.
        k1 (int): number of nearest neighbors to look for in the reference data for each
            target sample
        k2 (int): number of nearest neighbors to look for in the target data for each
            reference sample
        fk (int): number of nearest MNN-involved cells used to smooth each correction vector
        ndist (int): multiple of the middle neighbor distance used as the tricube bandwidth
        var_index (pd.Series): names of the variables (genes); must match the columns of both matrices
        var_subset (list): subset of variables to use when looking for MNNs; defaults to all
        cosine_norm (bool): whether to cosine-normalize the data before looking for MNNs
        n_jobs (int): number of jobs to use for parallelization; defaults to all available cores

    Returns:
        pd.DataFrame: corrected target matrix
        list: the MNN pairs found
    """
    if n_jobs is None:
        n_jobs = cpu_count()
    n_cols = ref_mat.shape[1]
    if len(var_index) != n_cols:
        raise ValueError('The number of vars is not equal to the length of var_index.')
    if targ_mat.shape[1] != n_cols:
        raise ValueError('The input matrices have inconsistent numbers of columns.')

    if var_subset is not None:
        subref_mat = ref_mat.loc[:, var_subset].values
        subtarg_mat = targ_mat.loc[:, var_subset].values
    else:
        subref_mat = ref_mat.values
        subtarg_mat = targ_mat.values

    if cosine_norm:
        print('Performing cosine normalization...')
        in_batches = _cosineNormalization(subref_mat, subtarg_mat,
                                          cos_norm_in=True, cos_norm_out=True, n_jobs=n_jobs)
        subref_mat, subtarg_mat = in_batches
        del in_batches

    print(' Looking for MNNs...')
    mnn_pairs = findMutualNN(data1=subref_mat, data2=subtarg_mat, k1=k1, k2=k2, n_jobs=n_jobs)
    print(' Found ' + str(len(mnn_pairs)) + ' mutual nearest neighbors.')
    mnn_ref, mnn_targ = np.array(mnn_pairs).T

    # TODO: this sorting shouldn't be needed
    idx = np.argsort(mnn_ref)
    mnn_ref = mnn_ref[idx]
    mnn_targ = mnn_targ[idx]

    # compute the overall batch vector
    corvec, _ = _averageCorrection(ref_mat.values, mnn_ref, targ_mat.values, mnn_targ)
    overall_batch = corvec.mean(axis=0)

    # remove variation along the overall batch vector
    ref_mat = _squashOnBatchDirection(ref_mat.values, overall_batch)
    targ = _squashOnBatchDirection(targ_mat.values, overall_batch)
    # recompute correction vectors and apply them
    re_ave_out, npairs = _averageCorrection(ref_mat, mnn_ref, targ, mnn_targ)
    del subref_mat, subtarg_mat, ref_mat
    # smooth the correction over the fk nearest MNN-involved target cells;
    # var_subset must hold positional column indices here, since targ is a bare array
    cols = slice(None) if var_subset is None else var_subset
    # TODO: why do cKDTree results depend on the ordering of the input matrix's data points?
    distances, index = cKDTree(np.take(targ, np.sort(npairs), 0)[:, cols]).query(
        x=targ[:, cols],
        k=min(fk, len(npairs)),
        n_jobs=n_jobs)
    targ_mat = pd.DataFrame(data=targ, columns=targ_mat.columns, index=targ_mat.index)
    targ_mat += _computeTricubeWeightedAvg(re_ave_out[np.argsort(npairs)], index, distances, ndist=ndist)
    return targ_mat, mnn_pairs

#@jit((float32[:, :], float32[:, :], int8, int8, int8))
def findMutualNN(data1, data2, k1, k2, n_jobs):
    """findMutualNN finds the mutual nearest neighbors between two sets of data.

    Args:
        data1 (np.array): samples-by-features reference matrix
        data2 (np.array): samples-by-features target matrix
        k1 (int): number of neighbors to look for in data1 for each sample of data2
        k2 (int): number of neighbors to look for in data2 for each sample of data1
        n_jobs (int): number of jobs to use for parallelization

    Returns:
        list: (index in data1, index in data2) tuples, one per mutual nearest neighbor pair
    """
    k_index_1 = cKDTree(data1).query(x=data2, k=k1, n_jobs=n_jobs)[1]
    k_index_2 = cKDTree(data2).query(x=data1, k=k2, n_jobs=n_jobs)[1]
    mutual = []
    for index_2, val in enumerate(k_index_1):
        for index_1 in val:
            if index_2 in k_index_2[index_1]:
                mutual.append((index_1, index_2))
    return mutual
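
A quick sanity check of what `findMutualNN` returns, on made-up 1-D data (this calls the function defined above; `n_jobs` is passed straight to the older `cKDTree.query` API used here):

```python
import numpy as np

# two toy batches forming two clusters, near 0 and near 10
data1 = np.array([[0.0], [1.0], [10.0], [11.0]])
data2 = np.array([[0.2], [0.4], [10.3]])

pairs = findMutualNN(data1, data2, k1=2, k2=2, n_jobs=1)
print(pairs)
# [(0, 0), (1, 0), (0, 1), (1, 1), (2, 2), (3, 2)]
# each tuple is (data1 index, data2 index); pairs only form within a cluster
```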

def _cosineNormalization(*datas, cos_norm_in, cos_norm_out, n_jobs):
    """_cosineNormalization transforms the input matrices to float32 and L2-normalizes their rows.

    Args:
        cos_norm_in (bool): whether to cosine-normalize the input data
        cos_norm_out (bool): currently unused
        n_jobs (int): number of jobs to use for parallelization

    Returns:
        list: the transformed matrices, as float32 np.arrays
    """
    datas = [data.toarray().astype(np.float32) if issparse(data) else data.astype(np.float32)
             for data in datas]
    with Pool(n_jobs) as p_n:
        in_scaling = p_n.map(l2_norm, datas)
    in_scaling = [scaling[:, None] for scaling in in_scaling]
    if cos_norm_in:
        with Pool(n_jobs) as p_n:
            datas = p_n.starmap(scale_rows, zip(datas, in_scaling))
    return datas
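
For intuition, cosine normalization scales every sample to unit length (assuming, as in mnnpy's utils, that `l2_norm` returns per-row L2 norms and `scale_rows` divides rows by them), so that Euclidean neighbor search behaves like cosine similarity:

```python
import numpy as np

x = np.array([[3.0, 4.0],
              [1.0, 0.0]], dtype=np.float32)
norms = np.linalg.norm(x, axis=1)[:, None]  # what l2_norm plus [:, None] produces
print(x / norms)                             # [[0.6, 0.8], [1.0, 0.0]] - unit-length rows
```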

def _averageCorrection(refdata, mnn1, curdata, mnn2):
    """_averageCorrection computes correction vectors for each MNN pair, then averages them
    over each MNN-involved cell in the second batch.

    Args:
        refdata (np.array): samples-by-genes matrix that serves as the reference data
            in the MNN alignment
        mnn1 (list): reference-side indices of the MNN pairs
        curdata (np.array): samples-by-genes matrix that is corrected in the MNN alignment
        mnn2 (list): target-side indices of the MNN pairs

    Returns:
        np.array: one averaged correction vector per MNN-involved target cell
        list: the corresponding target cell indices
    """
    npairs = pd.Series(mnn2).value_counts()
    corvec = np.take(refdata, mnn1, 0) - np.take(curdata, mnn2, 0)
    cor = np.zeros((len(npairs), corvec.shape[1]))
    mnn2 = np.array(mnn2)
    for i, v in enumerate(set(mnn2)):
        cor[i] = corvec[mnn2 == v].sum(0) / npairs[v]
    return cor, list(set(mnn2))
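
A small worked example of the helper above: the correction vectors from the two pairs hitting target cell 0 get averaged, while cell 1 keeps its single vector.

```python
import numpy as np

refdata = np.array([[1.0], [2.0], [5.0]])   # three reference cells, one gene
curdata = np.array([[0.0], [1.0]])          # two target cells
mnn1 = [0, 1, 2]                            # reference side of each pair
mnn2 = [0, 0, 1]                            # target side: cell 0 appears twice

cor, pairs = _averageCorrection(refdata, mnn1, curdata, mnn2)
print(pairs)  # [0, 1] (with this toy input)
print(cor)    # [[1.5], [4.0]]: mean of (1-0, 2-0) for cell 0; (5-1) for cell 1
```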

def _squashOnBatchDirection(mat, batch_vec):
    """_squashOnBatchDirection projects along the batch vector and shifts all samples
    to the center within each batch.

    Args:
        mat (np.array): samples-by-genes matrix
        batch_vec (np.array): batch vector

    Returns:
        np.array: corrected matrix
    """
    batch_vec = batch_vec / np.sqrt(np.sum(batch_vec ** 2))
    batch_loc = np.dot(mat, batch_vec)
    mat = mat + np.outer(np.mean(batch_loc) - batch_loc, batch_vec)
    return mat
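
To see what this projection does, a toy check (calling the helper defined above): after squashing, every sample has the same coordinate along the batch vector, while orthogonal variation is untouched.

```python
import numpy as np

mat = np.array([[1.0, 0.0],
                [3.0, 1.0],
                [5.0, 2.0]])
batch_vec = np.array([1.0, 0.0])  # pretend the batch effect lies along the first gene

squashed = _squashOnBatchDirection(mat, batch_vec)
print(squashed @ batch_vec)  # [3. 3. 3.]: all samples moved to the mean along batch_vec
print(squashed[:, 1])        # [0. 1. 2.]: the orthogonal direction is unchanged
```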

def _computeTricubeWeightedAvg(vals, indices, distances, bandwidth=None, ndist=3):
    """_computeTricubeWeightedAvg computes tricube-weighted averages of the correction vectors.

    Args:
        vals (np.array): correction vectors, one per MNN-involved target cell
        indices (np.array): n-by-k matrix of nearest neighbor indices
        distances (np.array): n-by-k matrix of nearest neighbor Euclidean distances
        bandwidth (float, optional): set to 'ndist' times the middle neighbor distance
            if not specified
        ndist (int, optional): bandwidth multiplier; defaults to 3

    Returns:
        np.array: n-by-genes matrix of weighted average correction vectors
    """
    if bandwidth is None:
        middle = int(np.floor(indices.shape[1] / 2))
        mid_dist = distances[:, middle]
        bandwidth = mid_dist * ndist
    bandwidth = np.maximum(1e-8, bandwidth)

    rel_dist = distances.T / bandwidth
    # don't use pmin(), as this destroys dimensions
    rel_dist[rel_dist > 1] = 1
    tricube = (1 - rel_dist ** 3) ** 3
    weight = tricube / np.sum(tricube, axis=0)
    del rel_dist, tricube, bandwidth

    output = np.zeros((indices.shape[0], vals.shape[1]))
    for kdx in range(indices.shape[1]):
        output += np.einsum("ij...,i...->ij...", vals[indices[:, kdx]], weight[kdx])
    return output
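
The weighting scheme above in isolation, on made-up distances: neighbors past the bandwidth (`ndist` times the middle neighbor distance) get zero weight, and each cell's weights sum to one.

```python
import numpy as np

distances = np.array([[0.5, 1.0, 4.0]])                 # one cell, its k=3 nearest neighbors
middle = int(np.floor(distances.shape[1] / 2))
bandwidth = np.maximum(1e-8, distances[:, middle] * 3)  # ndist = 3

rel_dist = distances.T / bandwidth
rel_dist[rel_dist > 1] = 1                              # cap: zero weight past the bandwidth
tricube = (1 - rel_dist ** 3) ** 3
weight = tricube / np.sum(tricube, axis=0)
print(weight.ravel())  # ~[0.52, 0.48, 0.]: the far neighbor is ignored
```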

###############################################################################
##################### Regular MNN Alignment ###################################
###############################################################################

def mnn_correct(*datas, var_index=None, var_subset=None, batch_key='batch', index_unique='-',
                batch_categories=None, k=20, sigma=1., cos_norm_in=True, cos_norm_out=True,
                batch_categories=None, k1=20, k2=20, sigma=1., cos_norm_in=True, cos_norm_out=True,
                svd_dim=None, var_adj=True, compute_angle=False, mnn_order=None, svd_mode='rsvd',
                do_concatenate=True, save_raw=False, n_jobs=None, **kwargs):
    """
@@ -120,7 +316,7 @@ def mnn_correct(*datas, var_index=None, var_subset=None, batch_key='batch', index_unique='-',
        if var_subset is not None and set(adata_vars) == set(var_subset):
            var_subset = None
        corrected = mnn_correct(*(adata.X for adata in datas), var_index=adata_vars,
                                var_subset=var_subset, k=k, sigma=sigma, cos_norm_in=cos_norm_in,
                                var_subset=var_subset, k1=k1, k2=k2, sigma=sigma, cos_norm_in=cos_norm_in,
                                cos_norm_out=cos_norm_out, svd_dim=svd_dim, var_adj=var_adj,
                                compute_angle=compute_angle, mnn_order=mnn_order,
                                svd_mode=svd_mode, do_concatenate=do_concatenate, **kwargs)
@@ -175,8 +371,12 @@ def mnn_correct(*datas, var_index=None, var_subset=None, batch_key='batch', index_unique='-',
        if not same_set:
            new_batch_out = out_batches[target]
        print(' Looking for MNNs...')
        mnn_ref, mnn_new = find_mutual_nn(data1=ref_batch_in, data2=new_batch_in, k1=k, k2=k,
        mnn = findMutualNN(data1=ref_batch_in, data2=new_batch_in, k1=k1, k2=k2,
                           n_jobs=n_jobs)
        val = np.array(mnn)
        mnn_ref = val[:, 0]
        mnn_new = val[:, 1]
        print(' Found ' + str(len(mnn_ref)) + ' MNNs...')
        print(' Computing correction vectors...')
        correction_in = compute_correction(ref_batch_in, new_batch_in, mnn_ref, mnn_new,
                                           new_batch_in, sigma)
16 changes: 7 additions & 9 deletions mnnpy/utils.py
@@ -85,21 +85,19 @@ def transform_input_data(datas, cos_norm_in, cos_norm_out, var_index, var_subset
    return in_batches, out_batches, var_sub_index, same_set


@jit((float32[:, :], float32[:, :], int8, int8, int8))
#@jit((float32[:, :], float32[:, :], int8, int8, int8))
def find_mutual_nn(data1, data2, k1, k2, n_jobs):
    k_index_1 = cKDTree(data1).query(x=data2, k=k1, n_jobs=n_jobs)[1]
    k_index_2 = cKDTree(data2).query(x=data1, k=k2, n_jobs=n_jobs)[1]
    mutual_1 = []
    mutual_2 = []
    for index_2 in range(data2.shape[0]):
        for index_1 in k_index_1[index_2]:
    mutual = []
    for index_2, val in enumerate(k_index_1):
        for index_1 in val:
            if index_2 in k_index_2[index_1]:
                mutual_1.append(index_1)
                mutual_2.append(index_2)
    return mutual_1, mutual_2
            mutual.append((index_1, index_2))
    return mutual


@jit(float32[:, :](float32[:, :], float32[:, :], int32[:], int32[:], float32[:, :], float32))
#@jit(float32[:, :](float32[:, :], float32[:, :], int32[:], int32[:], float32[:, :], float32))
def compute_correction(data1, data2, mnn1, mnn2, data2_or_raw2, sigma):
    vect = data1[mnn1] - data2[mnn2]
    mnn_index, mnn_count = np.unique(mnn2, return_counts=True)