In [1]:
%load_ext Cython
%matplotlib inline

import os
import sys
notebook_path = os.path.abspath('')
sources_path = os.path.abspath(os.path.join(notebook_path, '..', 'sources'))
sys.path.insert(0, sources_path)

import torch
import time
import random
import matplotlib.pylab as plt
import numpy as np

In [65]:
from fingerprint import SingleCellFingerprintBase, SingleCellFingerprintDTM

device = torch.device('cuda')
dtype = torch.float

dataset_name = 'pbmc4k'
sc_fingerprint_path = '/home/jupyter/data/10x/pbmc4k_sc_fingerprint.pkl'

zinb_fitter_kwargs = {
    'lr': 0.2,
    'max_iters': 10_000,
    'p_zero_l1_reg': 0.001,
    'outlier_stringency': 5.0,
    'max_zinb_p_zero': 0.8,
    'min_zinb_p_zero': 0.005,
    'min_nb_phi': 0.01,
    'max_nb_phi': 0.95
}

# load fingerprint and instantiate the data-store
sc_fingerprint_base = SingleCellFingerprintBase.load(sc_fingerprint_path)
sc_fingerprint_base = sc_fingerprint_base.filter_genes(
    max_good_turing=1.,
    min_total_gene_expression=1.).sort_genes_by_expression()

# Instantiate the fingerprint datastore
sc_fingerprint_dtm = SingleCellFingerprintDTM(
    sc_fingerprint_base,
    n_gene_groups=100,
    n_pca_features=100,
    zinb_fitter_kwargs=zinb_fitter_kwargs,
    max_estimated_chimera_family_size=0,
    allow_dense_int_ndarray=False,
    device=device,
    dtype=dtype)

Calculating and caching good_turing_estimator_g...
Calculating and caching total_molecules_per_gene_g...
Number of genes failed the maximum Good-Turing criterion: 13872
Number of genes failed the minimum expression criterion: 13872
Number of genes failed both criteria: 13872
Number of retained genes: 19822
Calculating and caching total_molecules_per_gene_g...
Calculating and caching total_molecules_per_gene_g...


NameError: name 'Dict' is not defined

In [68]:
x = sc_fingerprint_dtm._generate_stratified_sample(10, 5, 5)

In [73]:
np.sum(x['cell_sampling_site_scale_factor_array'])

93421086.0

In [74]:
sc_fingerprint_dtm.n_cells * sc_fingerprint_dtm.n_genes

93421086

In [36]:
%%cython
# distutils: language = c++

cimport cython

from cyrandom.cyrandom cimport randint

from cpython cimport bool
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
from cython.operator cimport dereference as deref, preincrement as inc

from libcpp.unordered_set cimport unordered_set as unordered_set
from libcpp.vector cimport vector as vector

import numpy as np
from typing import List

from libc.stdlib cimport rand, RAND_MAX
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint32_t


@cython.boundscheck(False)
@cython.wraparound(False)
cdef class CSRBinaryMatrix:
    cdef uint32_t n_rows
    cdef uint32_t n_cols
    cdef uint32_t* indptr
    cdef uint32_t* indices
    
    def __cinit__(
            self,
            size_t n_rows,
            size_t n_cols,
            size_t indices_sz,
            *args, **kwargs):
        self.n_rows = n_rows
        self.n_cols = n_cols

        # allocate memory
        self.indptr = <uint32_t*> PyMem_Malloc((n_rows + 1) * sizeof(uint32_t))
        self.indices = <uint32_t*> PyMem_Malloc(indices_sz * sizeof(uint32_t))
    
    def __init__(self,
            size_t n_rows,
            size_t n_cols,
            size_t indices_sz,
            uint32_t[:] indptr,
            uint32_t[:] indices,
            bool skip_copy = False,
            bool skip_validation = False):
        
        cdef size_t i, j
        if not skip_validation:
            assert len(indptr) == n_rows + 1, \
                f"The length of indptr ({len(indptr)}) does not match match n_rows + 1 ({n_rows + 1})"
            assert len(indices) == indices_sz, \
                f"The length of indices ({len(indices)}) does not match indices_sz ({indices_sz})"
            assert indptr[0] == 0, \
                "The first entry in indptr must be 0"
            assert indptr[n_rows] == indices_sz, \
                "The last entry in indptr must be equal to the length of indices"        
            for i in range(n_rows):
                assert indptr[i + 1] > indptr[i], \
                    "indptr must be strictly ascending"
                for j in range(indptr[i], indptr[i + 1] - 1):
                    assert indices[j + 1] > indices[j], \
                        "for each row, indices must be unique and sorted in ascending order"
                assert indices[indptr[i + 1] - 1] < n_cols, \
                    f"indices must be in range [0, {n_cols})"

        if not skip_copy:
            for i in range(n_rows + 1):
                self.indptr[i] = indptr[i]
            for i in range(indices_sz):
                self.indices[i] = indices[i]
                
    cpdef uint32_t get_non_zero_cols(self, uint32_t i_row):
        return self.indptr[i_row + 1] - self.indptr[i_row]

    def __invert__(self) -> CSRBinaryMatrix:
        """Returns the bitwise not of the matrix."""
        # allocate memory
        cdef uint32_t inv_indices_sz = self.n_rows * self.n_cols - self.indptr[self.n_rows]
        inv_array = CSRBinaryMatrix(
            n_rows=self.n_rows,
            n_cols=self.n_cols,
            indices_sz=inv_indices_sz,
            indptr=None,
            indices=None,
            skip_copy=True,
            skip_validation=True)

        cdef:
            uint32_t first_index, last_index
            uint32_t n_values, n_compl_values
            uint32_t i_row, i_col, c_index, inv_indices_loc
        
        inv_indices_loc = 0
        inv_array.indptr[0] = 0
        for i_row in range(self.n_rows):
            first_index = self.indptr[i_row]
            last_index = self.indptr[i_row + 1]
            n_values = last_index - first_index

            n_compl_values = self.n_cols - n_values
            c_index = first_index
            c_value = self.indices[c_index]
            for i_col in range(self.n_cols):
                if i_col < c_value:
                    inv_array.indices[inv_indices_loc] = i_col
                    inv_indices_loc += 1
                elif c_index < last_index - 1:
                    c_index += 1
                    c_value = self.indices[c_index]
                else:
                    c_value = self.n_cols
            inv_array.indptr[i_row + 1] = inv_indices_loc

        return inv_array
        
    cpdef vector[uint32_t] draw_from_row_with_replacement(
            self,
            uint32_t i_row,
            uint32_t n_samples):
        """Generates random samples from a given row with replacement."""
        cdef uint32_t first_index = self.indptr[i_row]
        cdef uint32_t n_values = self.indptr[i_row + 1] - first_index
        cdef Py_ssize_t i
        cdef vector[uint32_t] samples_vec
        for i in range(n_samples):
            samples_vec.push_back(self.indices[first_index + randint(0, n_values - 1)])
        return samples_vec

    cpdef unordered_set[uint32_t] draw_from_row_without_replacement(
            self,
            uint32_t i_row,
            uint32_t n_samples) except *:
        """Generates random samples from a given row without replacement via Floyd's algorithm."""
        cdef uint32_t first_index = self.indptr[i_row]
        cdef uint32_t n_values = self.indptr[i_row + 1] - first_index
        assert(n_samples <= n_values)
        
        # Floyd's algorithm for sampling without replacement
        cdef unordered_set[uint32_t] samples_set
        cdef uint32_t i, pos, item
        for i in range(n_values - n_samples, n_values):
            pos = randint(0, i)
            item = self.indices[first_index + pos]
            if samples_set.find(item) != samples_set.end():
                samples_set.insert(self.indices[first_index + i])
            else:
                samples_set.insert(item)
                
        return samples_set

    def __dealloc__(self):
        PyMem_Free(self.indptr)
        PyMem_Free(self.indices)


cdef class SingleCellFingerprintStratifiedSampler:
    # (n_gene_groups, n_genes)
    cdef CSRBinaryMatrix gene_groups_csr
    
    # (n_genes, n_cells)
    cdef CSRBinaryMatrix expressing_cells_csr
    
    # (n_genes, n_cells)
    cdef CSRBinaryMatrix silent_cells_csr

    cdef uint32_t n_genes
    cdef uint32_t n_cells
    cdef uint32_t n_gene_groups

    cdef uint32_t genes_per_gene_group
    cdef uint32_t expressing_cells_per_gene
    cdef uint32_t silent_cells_per_gene
    
    def __cinit__(
            self,
            CSRBinaryMatrix gene_groups_csr,
            CSRBinaryMatrix expressing_cells_csr,
            uint32_t genes_per_gene_group,
            uint32_t expressing_cells_per_gene,
            uint32_t silent_cells_per_gene):
        
        self.n_gene_groups = gene_groups_csr.n_rows
        self.n_genes = expressing_cells_csr.n_rows
        self.n_cells = expressing_cells_csr.n_cols
        assert gene_groups_csr.n_cols == self.n_genes, \
            (f"The column dim of gene_groups_csr ({gene_groups_csr.n_cols}) must "
             f"match the row dim of expressing_cells_csr ({expressing_cells_csr.n_rows})")
        
        self.gene_groups_csr = gene_groups_csr
        self.expressing_cells_csr = expressing_cells_csr
        
        self.silent_cells_csr = ~expressing_cells_csr
        
        self.genes_per_gene_group = genes_per_gene_group
        self.expressing_cells_per_gene = expressing_cells_per_gene
        self.silent_cells_per_gene = silent_cells_per_gene
        
    def draw(self):
        cdef vector[uint32_t] gene_indices_vec
        cdef vector[uint32_t] cell_indices_vec
        cdef vector[double] gene_sampling_site_scale_factor_vec
        cdef vector[double] cell_sampling_site_scale_factor_vec

        cdef uint32_t i_gene_group, i_gene, c_n_genes, c_gene_group_sz
        cdef double c_gene_scale_factor

        cdef uint32_t c_expressing_cells_sz, c_n_expressing_cells
        cdef double c_expressing_cell_scale_factor
                
        cdef uint32_t c_silent_cells_sz, c_n_silent_cells
        cdef double c_silent_cell_scale_factor

        cdef uint32_t c_total_cells_for_gene
        cdef double c_fractionalized_gene_scale_factor
        
        cdef unordered_set[uint32_t] c_gene_indices
        cdef unordered_set[uint32_t].iterator i_gene_it

        cdef unordered_set[uint32_t] c_expressing_cell_indices
        cdef unordered_set[uint32_t].iterator i_expressing_cell_it

        cdef unordered_set[uint32_t] c_silent_cell_indices
        cdef unordered_set[uint32_t].iterator i_silent_cell_it

        cdef size_t i
        
        # select genes
        for i_gene_group in range(self.n_gene_groups):
            
            # number of genes to draw from the gene group
            c_gene_group_sz = self.gene_groups_csr.get_non_zero_cols(i_gene_group)
            c_n_genes = min(self.genes_per_gene_group, c_gene_group_sz)
            c_gene_indices = self.gene_groups_csr.draw_from_row_without_replacement(
                i_gene_group, c_n_genes)
            
            # weight of randomly drawn genes
            c_gene_scale_factor = c_gene_group_sz / c_n_genes

            # select cells
            i_gene_it = c_gene_indices.begin()
            while i_gene_it != c_gene_indices.end():
                i_gene = deref(i_gene_it)
                inc(i_gene_it)
                
                # draw expressing cells from "i_gene"
                c_expressing_cells_sz = self.expressing_cells_csr.get_non_zero_cols(i_gene)
                c_n_expressing_cells = min(self.expressing_cells_per_gene, c_expressing_cells_sz)
                c_expressing_cell_indices = self.expressing_cells_csr.draw_from_row_without_replacement(
                    i_gene, c_n_expressing_cells)
                c_expressing_cell_scale_factor = (
                    c_gene_scale_factor * c_expressing_cells_sz / c_n_expressing_cells)

                # add to containers
                cell_indices_vec.insert(
                    cell_indices_vec.end(),
                    c_expressing_cell_indices.begin(),
                    c_expressing_cell_indices.end())
                
                for i in range(c_n_expressing_cells):
                    cell_sampling_site_scale_factor_vec.push_back(c_expressing_cell_scale_factor)

                # draw silent cells from "i_gene"
                c_silent_cells_sz = self.silent_cells_csr.get_non_zero_cols(i_gene)
                c_n_silent_cells = min(self.silent_cells_per_gene, c_silent_cells_sz)
                c_silent_cell_indices = self.silent_cells_csr.draw_from_row_without_replacement(
                    i_gene, c_n_silent_cells)
                c_silent_cell_scale_factor = (
                    c_gene_scale_factor * c_silent_cells_sz / c_n_silent_cells)

                # add to containers                
                cell_indices_vec.insert(
                    cell_indices_vec.end(),
                    c_silent_cell_indices.begin(),
                    c_silent_cell_indices.end())
                
                for i in range(c_n_silent_cells):
                    cell_sampling_site_scale_factor_vec.push_back(c_silent_cell_scale_factor)

                # gene sampling site effective ("fractionalized") scale factor
                c_total_cells_for_gene = c_n_expressing_cells + c_n_silent_cells
                c_fractionalized_gene_scale_factor = c_gene_scale_factor / c_total_cells_for_gene
                
                for i in range(c_total_cells_for_gene):
                    gene_sampling_site_scale_factor_vec.push_back(c_fractionalized_gene_scale_factor)
                    gene_indices_vec.push_back(i_gene)

        return {'gene_index_array': gene_indices_vec,
                'cell_index_array': cell_indices_vec,
                'cell_sampling_site_scale_factor_array': cell_sampling_site_scale_factor_vec,
                'gene_sampling_site_scale_factor_array': gene_sampling_site_scale_factor_vec}

In [37]:
# gene_groups_indptr_list = []
# gene_groups_indices_list = []
# gene_groups_indptr_list.append(0)
# for i_gene_group in range(sc_fingerprint_dtm.n_gene_groups):
#     gene_groups_indices_list += sc_fingerprint_dtm.gene_groups_dict[i_gene_group]
#     gene_groups_indptr_list.append(len(gene_groups_indices_list))
# gene_groups_indptr_ndarray = np.asarray(gene_groups_indptr_list, dtype=np.uint32)
# gene_groups_indices_ndarray = np.asarray(gene_groups_indices_list, dtype=np.uint32)

In [88]:
n_gene_groups = 100
n_genes = 50
n_cells = 1000
max_genes_per_gene_group = 5
max_expressing_cells = 100

np.random.seed(1234)

# random gene groups
indptr = np.cumsum(np.random.randint(
    1, max_genes_per_gene_group,
    dtype=np.uint32,
    size=n_gene_groups + 1)).astype(np.uint32)
indptr[0] = 0
indices_list = []
for i in range(n_gene_groups):
    sz = indptr[i + 1] - indptr[i]
    genes = sorted(np.random.choice(np.arange(0, n_genes), sz, replace=False))
    indices_list += genes
indices = np.asarray(indices_list, dtype=np.uint32)
gene_groups = CSRBinaryMatrix(
    n_rows=n_gene_groups,
    n_cols=n_genes,
    indices_sz=len(indices),
    indptr=indptr,
    indices=indices)

# random gene expression
indptr = np.cumsum(np.random.randint(
    1, max_expressing_cells,
    dtype=np.uint32,
    size=n_genes + 1)).astype(np.uint32)
indptr[0] = 0
indices_list = []
for i in range(n_genes):
    sz = indptr[i + 1] - indptr[i]
    cells = sorted(np.random.choice(np.arange(0, n_cells), sz, replace=False))
    indices_list += cells
indices = np.asarray(indices_list, dtype=np.uint32)

expressing_cells = CSRBinaryMatrix(
    n_rows=n_genes,
    n_cols=n_cells,
    indices_sz=len(indices),
    indptr=indptr,
    indices=indices)

In [103]:
indptr = [0]
indices = []
for i in range(sc_fingerprint_dtm.n_gene_groups):
    indices += sorted(sc_fingerprint_dtm.gene_groups_dict[i])
    indptr.append(len(indices))

gene_groups_csr = CSRBinaryMatrix(
    n_rows=sc_fingerprint_dtm.n_gene_groups,
    n_cols=sc_fingerprint_dtm.n_genes,
    indices_sz=len(indices),
    indptr=np.asarray(indptr, dtype=np.uint32),
    indices=np.asarray(indices, dtype=np.uint32))

indptr = [0]
indices = []
for i in range(sc_fingerprint_dtm.n_genes):
    indices += sorted(sc_fingerprint_dtm.get_expressing_cell_indices(i))
    indptr.append(len(indices))

expressing_cells_csr = CSRBinaryMatrix(
    n_rows=sc_fingerprint_dtm.n_genes,
    n_cols=sc_fingerprint_dtm.n_cells,
    indices_sz=len(indices),
    indptr=np.asarray(indptr, dtype=np.uint32),
    indices=np.asarray(indices, dtype=np.uint32))

In [104]:
sampler = SingleCellFingerprintStratifiedSampler(
    gene_groups_csr=gene_groups_csr,
    expressing_cells_csr=expressing_cells_csr,
    genes_per_gene_group=10,
    expressing_cells_per_gene=25,
    expressing_cells_per_gene=5)

SyntaxError: positional argument follows keyword argument (<ipython-input-104-50c9b227e4a8>, line 4)

In [94]:
for _ in range(50):
    sampler.draw()

In [95]:
x = sampler.draw()

In [96]:
sum(x['gene_sampling_site_scale_factor_array'])

250.0

In [97]:
250*1000

250000

In [98]:
x

{'gene_index_array': [8,
  8,
  37,
  37,
  31,
  31,
  34,
  34,
  5,
  5,
  44,
  44,
  31,
  31,
  4,
  4,
  30,
  30,
  2,
  2,
  12,
  12,
  37,
  37,
  40,
  40,
  8,
  8,
  37,
  37,
  44,
  44,
  37,
  37,
  21,
  21,
  14,
  14,
  47,
  47,
  5,
  5,
  45,
  45,
  15,
  15,
  8,
  8,
  8,
  8,
  27,
  27,
  26,
  26,
  2,
  2,
  42,
  42,
  25,
  25,
  11,
  11,
  36,
  36,
  15,
  15,
  0,
  0,
  10,
  10,
  42,
  42,
  26,
  26,
  13,
  13,
  22,
  22,
  16,
  16,
  15,
  15,
  48,
  48,
  29,
  29,
  43,
  43,
  11,
  11,
  21,
  21,
  32,
  32,
  1,
  1,
  2,
  2,
  11,
  11,
  33,
  33,
  44,
  44,
  15,
  15,
  43,
  43,
  32,
  32,
  15,
  15,
  1,
  1,
  43,
  43,
  43,
  43,
  18,
  18,
  23,
  23,
  10,
  10,
  44,
  44,
  9,
  9,
  44,
  44,
  11,
  11,
  29,
  29,
  44,
  44,
  46,
  46,
  12,
  12,
  2,
  2,
  38,
  38,
  38,
  38,
  27,
  27,
  10,
  10,
  22,
  22,
  17,
  17,
  25,
  25,
  39,
  39,
  43,
  43,
  9,
  9,
  11,
  11,
  27,
  27,
  42,
  42,
  18