Skip to content

Commit

Permalink
Merge pull request #49 from klarman-cell-observatory/boli
Browse files Browse the repository at this point in the history
Use cython fused types for sparse matrix outputs
  • Loading branch information
bli25 committed Dec 19, 2020
2 parents 673512a + f7a6154 commit 405769d
Showing 1 changed file with 23 additions and 40 deletions.
63 changes: 23 additions & 40 deletions ext_modules/io_funcs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@ cimport cython

ctypedef unsigned char uchar

ctypedef fused data_type:
int
long
float
double

ctypedef fused indices_type:
int
long

ctypedef fused indptr_type:
int
long



cdef const char* header_real = b"%%MatrixMarket matrix coordinate real general"
cdef const char* header_int = b"%%MatrixMarket matrix coordinate integer general"
Expand Down Expand Up @@ -76,42 +91,28 @@ cpdef tuple read_mtx(char* mtx_file):

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void write_mtx(char* mtx_file, object data, int[:] indices, int[:] indptr, int M, int N, int precision = 2):
cpdef void write_mtx(char* mtx_file, data_type[:] data, indices_type[:] indices, indptr_type[:] indptr, int M, int N, int precision = 2):
""" Input is csr_matrix internal representation, cell by gene; Output will be gene by cell
"""
cdef FILE* fo = fopen(mtx_file, "w")
cdef str fmt_str = ""
cdef char is_real = 0

cdef int[:] data_int
cdef float[:] data_float

if data.dtype.kind == 'f':
if (data_type is float) or (data_type is double):
fprintf(fo, "%s\n", header_real)
fmt_str = f"%d %d %.{precision}f\n"
is_real = 1
data_float = data
elif data.dtype.kind == 'i':
else:
fprintf(fo, "%s\n", header_int)
fmt_str = "%d %d %d\n"
is_real = 0
data_int = data
else:
raise ValueError(f"Detected unknown dtype: {data.dtype}!")

cdef const char* fmt = fmt_str
cdef Py_ssize_t data_size = data.size
cdef Py_ssize_t i, j

fprintf(fo, "%s\n", metadata)
fprintf(fo, "%d %d %zd\n", N, M, data_size)
fprintf(fo, "%d %d %zd\n", N, M, <long>data.size)

for i in range(M):
for j in range(indptr[i], indptr[i + 1]):
if is_real:
fprintf(fo, fmt, indices[j] + 1, i + 1, data_float[j])
else:
fprintf(fo, fmt, indices[j] + 1, i + 1, data_int[j])
fprintf(fo, fmt, indices[j] + 1, i + 1, data[j])

fclose(fo)

Expand Down Expand Up @@ -187,26 +188,11 @@ cpdef tuple read_csv(char* csv_file, char* delimiters):

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void write_dense(char* output_file, str[:] barcodes, str[:] features, object data, int[:] indices, int[:] indptr, int M, int N, int precision = 2):
cpdef void write_dense(char* output_file, str[:] barcodes, str[:] features, data_type[:] data, indices_type[:] indices, indptr_type[:] indptr, int M, int N, int precision = 2):
""" Input must be csr_matrix internal representation, gene by cell (X.T.tocsr()); Output will be gene by cell
"""
cdef FILE* fo = fopen(output_file, "w")
cdef str fmt_str = ""
cdef char is_real = 0

cdef int[:] data_int
cdef float[:] data_float

if data.dtype.kind == 'f':
fmt_str = f"\t%.{precision}f"
is_real = 1
data_float = data
elif data.dtype.kind == 'i':
fmt_str = "\t%d"
is_real = 0
data_int = data
else:
raise ValueError(f"Detected unknown dtype: {data.dtype}!")
cdef str fmt_str = f"\t%.{precision}f" if (data_type is float) or (data_type is double) else "\t%d"

cdef const char* fmt = fmt_str
cdef Py_ssize_t i, j, k, fr
Expand All @@ -222,10 +208,7 @@ cpdef void write_dense(char* output_file, str[:] barcodes, str[:] features, obje
for j in range(indptr[i], indptr[i + 1]):
for k in range(fr, indices[j]):
fprintf(fo, "\t0")
if is_real:
fprintf(fo, fmt, data_float[j])
else:
fprintf(fo, fmt, data_int[j])
fprintf(fo, fmt, data[j])
fr = indices[j] + 1
for k in range(fr, N):
fprintf(fo, "\t0")
Expand Down

0 comments on commit 405769d

Please sign in to comment.