In [3]:
%load_ext Cython
%matplotlib inline

import os
import sys
# Make modules under ../sources (relative to the notebook's working directory)
# importable, ahead of any installed package with the same name.
notebook_path = os.path.abspath(os.curdir)
sources_path = os.path.abspath(os.path.join(notebook_path, os.pardir, 'sources'))
sys.path.insert(0, sources_path)

import torch
import time
import random
import matplotlib.pylab as plt
import numpy as np

In [4]:
# from fingerprint import SingleCellFingerprintBase, SingleCellFingerprintDTM

# device = torch.device('cuda')
# dtype = torch.float

# dataset_name = 'pbmc4k'
# sc_fingerprint_path = '/home/jupyter/data/10x/pbmc4k_sc_fingerprint.pkl'

# zinb_fitter_kwargs = {
#     'lr': 0.2,
#     'max_iters': 10_000,
#     'p_zero_l1_reg': 0.001,
#     'outlier_stringency': 5.0,
#     'max_zinb_p_zero': 0.8,
#     'min_zinb_p_zero': 0.005,
#     'min_nb_phi': 0.01,
#     'max_nb_phi': 0.95
# }

# # load fingerprint and instantiate the data-store
# sc_fingerprint_base = SingleCellFingerprintBase.load(sc_fingerprint_path)
# sc_fingerprint_base = sc_fingerprint_base.filter_genes(
#     max_good_turing=1.,
#     min_total_gene_expression=1.).sort_genes_by_expression()

# # Instantiate the fingerprint datastore
# sc_fingerprint_dtm = SingleCellFingerprintDTM(
#     sc_fingerprint_base,
#     n_gene_groups=100,
#     n_pca_features=100,
#     zinb_fitter_kwargs=zinb_fitter_kwargs,
#     max_estimated_chimera_family_size=0,
#     allow_dense_int_ndarray=False,
#     device=device,
#     dtype=dtype)


# def random_choice(a: List[int], size: int) -> List[int]:
#     return random.sample(a, size)

# def _generate_stratified_sample(
#         mb_genes_per_gene_group: int,
#         mb_expressing_cells_per_gene: int,
#         mb_silent_cells_per_gene: int) -> Dict[str, np.ndarray]:
#     mb_cell_indices_per_gene = []
#     mb_cell_scale_factors_per_gene = []
#     mb_effective_gene_scale_factors_per_cell = []
#     mb_gene_indices_per_cell = []

#     # select genes
#     for i_gene_group in range(self.n_gene_groups):
#         gene_size = min(mb_genes_per_gene_group, len(self.gene_groups_dict[i_gene_group]))
#         gene_indices = random_choice(self.gene_groups_dict[i_gene_group], gene_size)
#         gene_scale_factor = len(self.gene_groups_dict[i_gene_group]) / gene_size

#         # select cells
#         for gene_index in gene_indices:
#             # sample from expressing cells
#             size_expressing = min(mb_expressing_cells_per_gene, self.n_expressing_cells_per_gene[gene_index])
#             expressing_cell_indices = random_choice(
#                 self.get_expressing_cell_indices(gene_index), size_expressing)
#             expressing_scale_factor = gene_scale_factor * (
#                     self.n_expressing_cells_per_gene[gene_index] / size_expressing)

#             # sample from silent cells
#             size_silent = min(mb_silent_cells_per_gene, self.n_silent_cells_per_gene[gene_index])
#             silent_cell_indices = random_choice(self.get_silent_cell_indices(gene_index), size_silent)
#             silent_scale_factor = gene_scale_factor * (
#                     self.n_silent_cells_per_gene[gene_index] / size_silent)

#             mb_cell_indices_per_gene.append(np.asarray(expressing_cell_indices))
#             mb_cell_indices_per_gene.append(np.asarray(silent_cell_indices))
#             mb_cell_scale_factors_per_gene.append(expressing_scale_factor * np.ones((size_expressing,)))
#             mb_cell_scale_factors_per_gene.append(silent_scale_factor * np.ones((size_silent,)))

#             # the effective scale factor for a collapsed gene sampling site is scaled down by the number of cells
#             total_cells_per_gene = size_expressing + size_silent
#             effective_gene_scale_factor = gene_scale_factor / total_cells_per_gene
#             mb_effective_gene_scale_factors_per_cell.append(
#                 effective_gene_scale_factor * np.ones((total_cells_per_gene,)))
#             mb_gene_indices_per_cell.append(gene_index * np.ones((total_cells_per_gene,), dtype=np.int))

#     gene_index_array = np.concatenate(mb_gene_indices_per_cell)
#     cell_index_array = np.concatenate(mb_cell_indices_per_gene)
#     cell_sampling_site_scale_factor_array = np.concatenate(mb_cell_scale_factors_per_gene)
#     gene_sampling_site_scale_factor_array = np.concatenate(mb_effective_gene_scale_factors_per_cell)

#     return {'gene_index_array': gene_index_array,
#             'cell_index_array': cell_index_array,
#             'cell_sampling_site_scale_factor_array': cell_sampling_site_scale_factor_array,
#             'gene_sampling_site_scale_factor_array': gene_sampling_site_scale_factor_array}

In [None]:
%%cython

cimport cython
from cyrandom.cyrandom cimport randint
from cpython cimport bool
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
import numpy as np

from libc.stdlib cimport rand, RAND_MAX
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint32_t

cdef class RaggedBinary2DArray:
    """A binary 2D array storing, per row, the sorted column indices of its set entries.

    The layout is CSR-like without a data array: the set columns of row ``i`` are
    ``indices[indptr[i]:indptr[i + 1]]``.  Both buffers are allocated in
    ``__cinit__`` with ``PyMem_Malloc`` and released in ``__dealloc__``.
    """
    cdef uint32_t n_rows
    cdef uint32_t n_cols
    cdef uint32_t* indptr   # n_rows + 1 row pointers
    cdef uint32_t* indices  # indices_sz column indices

    def __cinit__(
            self,
            size_t n_rows,
            size_t n_cols,
            size_t indices_sz,
            *args, **kwargs):
        # `*args, **kwargs` absorb the remaining constructor arguments meant for `__init__`
        self.n_rows = n_rows
        self.n_cols = n_cols

        # allocate memory (contents are filled in by `__init__` or `__invert__`)
        self.indptr = <uint32_t*> PyMem_Malloc((n_rows + 1) * sizeof(uint32_t))
        self.indices = <uint32_t*> PyMem_Malloc(indices_sz * sizeof(uint32_t))
        if self.indptr == NULL or (self.indices == NULL and indices_sz > 0):
            # release whatever was obtained and leave NULLs behind (PyMem_Free(NULL) is a no-op)
            PyMem_Free(self.indptr)
            PyMem_Free(self.indices)
            self.indptr = NULL
            self.indices = NULL
            raise MemoryError()

    def __init__(self,
            size_t n_rows,
            size_t n_cols,
            size_t indices_sz,
            uint32_t[:] indptr,
            uint32_t[:] indices,
            bool skip_copy = False,
            bool skip_validation = False):
        """Validates (unless ``skip_validation``) and copies (unless ``skip_copy``) the CSR buffers.

        Args:
            n_rows: number of rows.
            n_cols: number of columns; every stored index must be < n_cols.
            indices_sz: total number of set entries, i.e. the length of ``indices``.
            indptr: uint32 row pointers of length ``n_rows + 1``; must start at 0, be
                strictly ascending (so every row is non-empty) and end at ``indices_sz``.
            indices: uint32 column indices; unique and ascending within each row.
            skip_copy: leave the internal buffers unfilled (``PyMem_Malloc`` memory is
                uninitialized); ``__invert__`` uses this and fills them itself.
            skip_validation: skip the consistency assertions below.

        Raises:
            AssertionError: if validation is enabled and the inputs are inconsistent.
        """
        cdef size_t i, j
        if not skip_validation:
            assert len(indptr) == n_rows + 1, \
                f"The length of indptr ({len(indptr)}) does not match n_rows + 1 ({n_rows + 1})"
            assert len(indices) == indices_sz, \
                f"The length of indices ({len(indices)}) does not match indices_sz ({indices_sz})"
            assert indptr[0] == 0, \
                "The first entry in indptr must be 0"
            assert indptr[n_rows] == indices_sz, \
                "The last entry in indptr must be equal to the length of indices"
            for i in range(n_rows):
                assert indptr[i + 1] > indptr[i], \
                    "indptr must be strictly ascending"
                for j in range(indptr[i], indptr[i + 1] - 1):
                    assert indices[j + 1] > indices[j], \
                        "for each row, indices must be unique and sorted in ascending order"
                # indices are sorted, so checking the last one of the row bounds them all
                assert indices[indptr[i + 1] - 1] < n_cols, \
                    f"indices must be in range [0, {n_cols})"

        if not skip_copy:
            for i in range(n_rows + 1):
                self.indptr[i] = indptr[i]
            for i in range(indices_sz):
                self.indices[i] = indices[i]

    def row(self, Py_ssize_t i_row) -> list:
        """Returns the sorted column indices set in row ``i_row`` as a Python list.

        Raises:
            IndexError: if ``i_row`` is not in ``[0, n_rows)``.
        """
        if i_row < 0 or i_row >= self.n_rows:
            raise IndexError(f"row index {i_row} is out of range [0, {self.n_rows})")
        return [self.indices[j] for j in range(self.indptr[i_row], self.indptr[i_row + 1])]

    def __invert__(self) -> RaggedBinary2DArray:
        """Returns the element-wise complement as a new array.

        Row ``i`` of the result lists, in ascending order, every column in
        ``[0, n_cols)`` that is not set in row ``i`` of ``self``.  Full rows become
        empty rows and empty rows become full rows, so ``~~a`` round-trips.

        Raises:
            OverflowError: if the complement has more entries than a uint32
                ``indptr`` can address.
        """
        # compute in 64-bit: n_rows * n_cols can exceed the uint32 range
        cdef unsigned long long inv_indices_sz = \
            (<unsigned long long> self.n_rows) * self.n_cols - self.indptr[self.n_rows]
        cdef RaggedBinary2DArray inv_array
        cdef uint32_t first_index, last_index, k
        cdef uint32_t i_row, i_col, inv_indices_loc

        if inv_indices_sz > 0xFFFFFFFF:
            raise OverflowError(
                f"the complement has {inv_indices_sz} entries, which exceeds the uint32 range")

        # allocate memory; the buffers are filled in below
        inv_array = RaggedBinary2DArray(
            n_rows=self.n_rows,
            n_cols=self.n_cols,
            indices_sz=inv_indices_sz,
            indptr=None,
            indices=None,
            skip_copy=True,
            skip_validation=True)

        inv_indices_loc = 0
        inv_array.indptr[0] = 0
        for i_row in range(self.n_rows):
            first_index = self.indptr[i_row]
            last_index = self.indptr[i_row + 1]
            # merge-walk: `k` points at the next set column of this row; every column
            # it does not match belongs to the complement.  Never reads past the row,
            # so empty rows are handled correctly.
            k = first_index
            for i_col in range(self.n_cols):
                if k < last_index and self.indices[k] == i_col:
                    k += 1
                else:
                    inv_array.indices[inv_indices_loc] = i_col
                    inv_indices_loc += 1
            inv_array.indptr[i_row + 1] = inv_indices_loc

        return inv_array

    def __dealloc__(self):
        PyMem_Free(self.indptr)
        PyMem_Free(self.indices)
    

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void generate_random_samples_from_row(
        RaggedBinary2DArray ragged,
        uint32_t i_row,
        uint32_t n_samples,
        uint32_t* out) nogil:
    """Writes ``n_samples`` column indices drawn (with replacement) from row ``i_row`` into ``out``.

    Each draw picks a position via ``randint(0, n_values - 1)`` (cyrandom; assumed
    inclusive on both ends like Python's ``random.randint`` -- TODO confirm).

    Preconditions (not checked; raw pointer access):
        * ``i_row < ragged.n_rows``;
        * row ``i_row`` is non-empty (an empty row makes ``n_values - 1`` wrap around);
        * ``out`` has room for ``n_samples`` entries.
    """
    cdef uint32_t* row_start = ragged.indices + ragged.indptr[i_row]
    cdef uint32_t last_offset = ragged.indptr[i_row + 1] - ragged.indptr[i_row] - 1
    cdef uint32_t n_written = 0
    while n_written < n_samples:
        out[n_written] = row_start[randint(0, last_offset)]
        n_written += 1


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void generate_random_samples_from_row_compl(
        RaggedBinary2DArray ragged,
        uint32_t i_row,
        uint32_t n_samples,
        uint32_t* out) except *:
    """Generates random samples from the complement of the row of a ragged 2D binary array.

    Draws ``n_samples`` column indices (with replacement) from the columns in
    ``[0, ragged.n_cols)`` that are not set in row ``i_row`` and writes them to ``out``.
    The complement is rebuilt in a scratch buffer on every call.

    Args:
        ragged: the ragged binary array.
        i_row: row to sample the complement of; assumed < ``ragged.n_rows`` (not checked).
        n_samples: number of samples to draw.
        out: caller-owned buffer with room for at least ``n_samples`` entries.

    Raises:
        ValueError: if ``n_samples > 0`` but the row covers every column (empty complement).
        MemoryError: if the scratch buffer cannot be allocated.
    """
    cdef uint32_t first_index = ragged.indptr[i_row]
    cdef uint32_t last_index = ragged.indptr[i_row + 1]
    cdef uint32_t n_compl_values = ragged.n_cols - (last_index - first_index)
    cdef uint32_t* compl_values
    cdef uint32_t k = first_index  # cursor over the row's sorted set columns
    cdef uint32_t j = 0
    cdef uint32_t i_col
    cdef Py_ssize_t i

    if n_samples == 0:
        return
    if n_compl_values == 0:
        raise ValueError(f"row {i_row} covers all columns; its complement is empty")

    compl_values = <uint32_t*> PyMem_Malloc(n_compl_values * sizeof(uint32_t))
    if compl_values == NULL:
        raise MemoryError()
    try:
        # find complementary values with a merge-walk that never reads past the row
        # (safe for empty rows, unlike peeking at indices[first_index] unconditionally)
        for i_col in range(ragged.n_cols):
            if k < last_index and ragged.indices[k] == i_col:
                k += 1
            else:
                compl_values[j] = i_col
                j += 1

        # sample from complementary values
        for i in range(n_samples):
            out[i] = compl_values[randint(0, n_compl_values - 1)]
    finally:
        PyMem_Free(compl_values)


cdef list _collect_row_samples(
        RaggedBinary2DArray ragged,
        uint32_t i_row,
        uint32_t n_samples,
        bint from_complement):
    """Runs the requested cdef sampler into a scratch buffer and returns the draws as a list."""
    cdef uint32_t* out = <uint32_t*> PyMem_Malloc(n_samples * sizeof(uint32_t))
    if out == NULL and n_samples > 0:
        raise MemoryError()
    try:
        if from_complement:
            generate_random_samples_from_row_compl(ragged, i_row, n_samples, out)
        else:
            generate_random_samples_from_row(ragged, i_row, n_samples, out)
        return [out[i] for i in range(n_samples)]
    finally:
        PyMem_Free(out)


def generate_random_samples_from_row_test(
        RaggedBinary2DArray ragged not None,
        uint32_t i_row,
        uint32_t n_samples):
    """Python-facing wrapper: draws ``n_samples`` column indices from row ``i_row``.

    Returns:
        list of ``n_samples`` column indices sampled with replacement from the row.

    Raises:
        IndexError: if ``i_row`` is not a valid row.
        ValueError: if ``n_samples > 0`` and the row is empty.
        MemoryError: if the output buffer cannot be allocated.
    """
    # the cdef sampler uses raw pointers, so validate here before calling it
    if i_row >= ragged.n_rows:
        raise IndexError(f"i_row ({i_row}) is out of range [0, {ragged.n_rows})")
    if n_samples > 0 and ragged.indptr[i_row + 1] == ragged.indptr[i_row]:
        raise ValueError(f"row {i_row} is empty; there is nothing to sample from")
    return _collect_row_samples(ragged, i_row, n_samples, False)


def generate_random_samples_from_row_compl_test(
        RaggedBinary2DArray ragged not None,
        uint32_t i_row,
        uint32_t n_samples):
    """Python-facing wrapper: draws ``n_samples`` column indices NOT set in row ``i_row``.

    Returns:
        list of ``n_samples`` column indices sampled with replacement from the row's complement.

    Raises:
        IndexError: if ``i_row`` is not a valid row.
        ValueError: if ``n_samples > 0`` and the row covers every column.
        MemoryError: if a buffer cannot be allocated.
    """
    if i_row >= ragged.n_rows:
        raise IndexError(f"i_row ({i_row}) is out of range [0, {ragged.n_rows})")
    if n_samples > 0 and ragged.indptr[i_row + 1] - ragged.indptr[i_row] == ragged.n_cols:
        raise ValueError(f"row {i_row} covers all columns; its complement is empty")
    return _collect_row_samples(ragged, i_row, n_samples, True)


cdef class SingleCellFingerprintStratifiedSampler:
    """Holds the incidence structures for a stratified (gene group -> gene -> cell) sampler.

    Work in progress: it currently stores the arrays and precomputes the silent-cell
    complement; no sampling method is implemented yet.
    """
    # (gene groups, gene indices)
    cdef RaggedBinary2DArray gene_groups

    # (gene indices, expressing cells)
    cdef RaggedBinary2DArray expressing_cells

    # (gene indices, silent cells) -- the complement of `expressing_cells`
    cdef RaggedBinary2DArray silent_cells

    def __cinit__(
            self,
            RaggedBinary2DArray gene_groups not None,
            RaggedBinary2DArray expressing_cells not None):
        # `not None`: a None here would otherwise be stored silently (gene_groups)
        # or fail with an unrelated `~None` TypeError (expressing_cells)
        self.gene_groups = gene_groups
        self.expressing_cells = expressing_cells
        # NOTE(review): materializes the full complement, i.e. n_genes * n_cells minus
        # the expressing entries as uint32 values; this may be large on real datasets.
        self.silent_cells = ~expressing_cells

    def foo(self):
        """Placeholder; currently does nothing."""
        pass
    

In [6]:
# gene_groups_indptr_list = []
# gene_groups_indices_list = []
# gene_groups_indptr_list.append(0)
# for i_gene_group in range(sc_fingerprint_dtm.n_gene_groups):
#     gene_groups_indices_list += sc_fingerprint_dtm.gene_groups_dict[i_gene_group]
#     gene_groups_indptr_list.append(len(gene_groups_indices_list))
# gene_groups_indptr_ndarray = np.asarray(gene_groups_indptr_list, dtype=np.uint32)
# gene_groups_indices_ndarray = np.asarray(gene_groups_indices_list, dtype=np.uint32)

In [7]:
# Synthetic gene-group incidence data: each of the `n_gene_groups` rows gets 1 or 2
# distinct genes out of `n_genes`, stored in CSR-like (indptr, indices) form.
np.random.seed(1234)

n_gene_groups = 100
n_genes = 20_000

# group sizes are drawn from {1, 2}; the first draw is overwritten so that indptr[0] == 0
group_sizes = np.random.randint(1, 3, dtype=np.uint32, size=n_gene_groups + 1)
indptr = np.cumsum(group_sizes).astype(np.uint32)
indptr[0] = 0

# per-group genes, sorted ascending as RaggedBinary2DArray requires
genes_per_group = [
    np.sort(np.random.choice(np.arange(0, n_genes), indptr[i_group + 1] - indptr[i_group], replace=False))
    for i_group in range(n_gene_groups)
]
indices = np.concatenate(genes_per_group).astype(np.uint32)

gene_groups = RaggedBinary2DArray(
    n_rows=n_gene_groups,
    n_cols=n_genes,
    indices_sz=len(indices),
    indptr=indptr,
    indices=indices)

In [8]:
from collections import Counter

In [9]:
# Display the object; RaggedBinary2DArray defines no __repr__, so this shows the default repr.
gene_groups

<_cython_magic_7c6f2d6ef13b7a08292402ee6caf2949.RaggedBinary2DArray at 0x7fca604674b8>

In [10]:
# Element-wise complement (RaggedBinary2DArray.__invert__): row i lists every gene index NOT in group i.
inv = ~gene_groups

In [11]:
# Sanity check: 50,000 draws from row 5; a single distinct value in the output means that row holds one gene.
Counter(generate_random_samples_from_row_test(gene_groups, 5, 50000))

Counter({13088: 50000})

In [14]:
# Draw 500,000 column indices from row 4 of the complement `inv` and tally them.
# Bound to a descriptive name instead of `_`, which IPython uses for the previous output.
inv_row4_counts = Counter(generate_random_samples_from_row_test(inv, 4, 500000))

In [21]:
# Repeatedly draw 50 column indices from the complement of row 5 of `gene_groups`.
# Sampling row 5 of the precomputed complement `inv = ~gene_groups` is equivalent and
# only uses the Python-visible wrapper, so the cell runs on a fresh kernel.
for i_rep in range(50000):
    generate_random_samples_from_row_test(inv, 5, 50)