# custom class for storing and processing single cell datasets

In [None]:
import os, scprep, magic
import pandas as pd
import scipy as sp

In [None]:
from scipy.sparse import csr_matrix
from scprep.filter import filter_rare_genes, filter_library_size
from scprep.normalize import library_size_normalize
from scprep.transform import log
from magic import MAGIC

In [1]:
class dataset:
    """Container for one single-cell dataset and the matrices derived from it.

    The intended call order is:

    1. ``raw_counts_from_sparse`` -- load the raw count matrix.
    2. ``scprep_preprocess``      -- filter / normalize / log-transform it.
    3. ``magic_imputation``       -- impute from the normalized matrix.

    Every attribute starts as ``None`` and is filled in by the matching step.
    ``pseudotimes`` and ``clusters`` are only initialised here; nothing in
    this class assigns them.
    """

    def __init__(self, name):
        """Create an empty dataset labelled ``name``."""

        # whenever updating this, need to update save_dataset func

        self.name = name
        self.raw_counts = None
        # CSR components are kept next to the dense frame exactly as supplied
        # to raw_counts_from_sparse.
        self.raw_counts_data = None
        self.raw_counts_indices = None
        self.raw_counts_indptr = None
        self.raw_counts_shape = None
        self.normalized = None
        self.imputed = None
        self.pseudotimes = None
        self.clusters = None
        # NOTE(review): embedding attributes are currently disabled.
#         self.pca_embedding = None
#         self.umap_embedding = None
#         self.tsne_embedding = None

    def raw_counts_from_sparse(self, cell_names, gene_names,
                               data, indices, indptr, shape, dtype):
        """Load raw counts from the components of a CSR matrix.

        The CSR components are stored on the instance unchanged, and a dense
        ``pandas.DataFrame`` built from them is stored in ``self.raw_counts``.

        Parameters
        ----------
        cell_names : sequence
            Row labels (index) of the resulting DataFrame.
        gene_names : sequence
            Column labels of the resulting DataFrame.
        data, indices, indptr : array-like
            CSR representation, passed to ``scipy.sparse.csr_matrix`` as
            ``(data, indices, indptr)``.
        shape : tuple
            Shape of the matrix, passed as ``shape=``.
        dtype : dtype-like
            Element type, passed as ``dtype=``.
        """

        self.raw_counts_data = data
        self.raw_counts_indices = indices
        self.raw_counts_indptr = indptr
        self.raw_counts_shape = shape

        sparse_counts = csr_matrix(
            (self.raw_counts_data,
             self.raw_counts_indices,
             self.raw_counts_indptr),
            shape=self.raw_counts_shape,
            dtype=dtype)

        # The full matrix is densified here, so memory use scales with
        # shape[0] * shape[1] rather than with the number of stored entries.
        self.raw_counts = pd.DataFrame(data=sparse_counts.toarray(),
                                       index=cell_names,
                                       columns=gene_names)

    def scprep_preprocess(self, cutoff):
        """Run the scprep preprocessing chain and store it in ``self.normalized``.

        Applies, in order: ``filter_rare_genes``, ``filter_library_size``
        (with ``cutoff=cutoff``), ``library_size_normalize`` and ``log``.
        All other options of those functions are left at scprep's defaults.

        Parameters
        ----------
        cutoff
            Forwarded as ``cutoff=`` to ``scprep.filter.filter_library_size``.

        Raises
        ------
        RuntimeError
            If no raw counts have been loaded yet.
        """

        # Fail early with a clear message instead of handing None to scprep.
        if self.raw_counts is None:
            raise RuntimeError(
                "raw counts are not loaded; call raw_counts_from_sparse() "
                "before scprep_preprocess()")

        filtered = filter_rare_genes(self.raw_counts)
        filtered = filter_library_size(filtered, cutoff=cutoff)
        scaled = library_size_normalize(filtered)
        self.normalized = log(scaled)

    def magic_imputation(self, genes):
        """Impute with MAGIC and store the result in ``self.imputed``.

        A ``MAGIC`` operator is created with ``t='auto'``, ``verbose=False``
        and ``random_state=0``, then fit-transformed on ``self.normalized``.

        Parameters
        ----------
        genes
            Forwarded as ``genes=`` to ``MAGIC.fit_transform``.

        Raises
        ------
        RuntimeError
            If the normalized matrix has not been computed yet.
        """

        # Fail early with a clear message instead of handing None to MAGIC.
        if self.normalized is None:
            raise RuntimeError(
                "normalized data is not available; call scprep_preprocess() "
                "before magic_imputation()")

        magic_op = MAGIC(t='auto',
                         verbose=False,
                         random_state=0)

        self.imputed = magic_op.fit_transform(self.normalized,
                                              genes=genes)