# Custom class for storing and processing single-cell datasets

In [None]:
import os, scprep, magic
import pandas as pd
import numpy as np
import scipy as sp

In [None]:
from scipy.sparse import csr_matrix
from scprep.filter import filter_rare_genes, filter_library_size
from scprep.normalize import library_size_normalize
from scprep.transform import log
from magic import MAGIC

In [1]:
class dataset:
    """
    Container for one single-cell expression dataset and its derived views.

    Every attribute except ``name`` starts as ``None``:

    * name        -- label passed to the constructor
    * raw_counts  -- dense cells x genes DataFrame, filled by
                     raw_counts_from_sparse_matrix
    * normalized  -- filled by scprep_preprocessing
    * imputed     -- filled by magic_imputation
    * pseudotimes -- never assigned by this class; pseudotime_binning
                     reads it as a pandas Series whose index labels are
                     rows of the expression matrix
    * clusters    -- never assigned or read by this class
    * pt_binned   -- filled by pseudotime_binning

    """

    def __init__(self, name):
        """
        Create an empty dataset that only carries its name.

        """

        self.name = name
        self.raw_counts = None
        self.normalized = None
        self.imputed = None
        self.pseudotimes = None
        self.clusters = None
        self.pt_binned = None

#         self.pca_embedding = None
#         self.umap_embedding = None
#         self.tsne_embedding = None

    def raw_counts_from_sparse_matrix(
        self, cell_names, gene_names, data,
        indices, indptr, shape, dtype):
        """
        Build ``self.raw_counts`` from the pieces of a CSR matrix.

        * cell_names -- row labels of the resulting DataFrame
        * gene_names -- column labels of the resulting DataFrame
        * data, indices, indptr -- the three CSR component arrays
        * shape, dtype -- forwarded to scipy's csr_matrix

        The sparse matrix is converted with ``toarray()``, so the full
        dense matrix is held in memory.

        """

        sparse_counts = csr_matrix(
            (data, indices, indptr),
            shape=shape,
            dtype=dtype)

        self.raw_counts = pd.DataFrame(
            data=sparse_counts.toarray(),
            index=cell_names,
            columns=gene_names)

    def scprep_preprocessing(self, cutoff):
        """
        Preprocess raw counts with scprep and store the result in
        ``self.normalized``:

        * filter rare genes
        * filter cells with low counts
        * normalize library sizes
        * log scale

        ``cutoff`` is forwarded as the ``cutoff`` argument of
        scprep's filter_library_size.

        """

        self.normalized = filter_rare_genes(
                                self.raw_counts)
        self.normalized = filter_library_size(
                                self.normalized,
                                cutoff=cutoff)
        self.normalized = library_size_normalize(
                                self.normalized)
        self.normalized = log(self.normalized)

    def magic_imputation(self, genes):
        """
        Impute missing expression from normalized data with MAGIC
        and store the result in ``self.imputed``.

        ``genes`` is forwarded unchanged as the ``genes`` argument of
        MAGIC's fit_transform. The operator is created with
        ``random_state=0``.

        """

        magic_op = MAGIC(t='auto',
                    verbose=False,
                    random_state=0)

        self.imputed = magic_op.fit_transform(
                                self.normalized,
                                genes=genes)

    def pseudotime_binning(self, data, bin_size):
        """
        Average expression over consecutive pseudotime bins and store
        the result in ``self.pt_binned``.

        * data     -- 'raw_counts', 'normalized' or 'imputed'; any other
                      value falls back to the raw counts
        * bin_size -- width of each bin, must be positive

        Bins start at the smallest pseudotime and are half-open,
        ``[lo, lo + bin_size)``; the last bin also contains the largest
        pseudotime, so every cell with a non-missing pseudotime lands in
        exactly one bin. Each row of ``pt_binned`` is labelled with the
        midpoint of its bin, and bins without cells produce no row.

        Raises ValueError if the pseudotimes or the selected matrix have
        not been set, or if ``bin_size`` is not positive.

        """

        sources = {
            'raw_counts': self.raw_counts,
            'normalized': self.normalized,
            'imputed': self.imputed,
        }
        X = sources.get(data, self.raw_counts)

        if X is None:
            raise ValueError(
                "expression matrix for binning has not been computed")
        if self.pseudotimes is None:
            raise ValueError(
                "pseudotimes must be set before binning")
        if not bin_size > 0:
            raise ValueError("bin_size must be positive")

        # Cells without a pseudotime cannot be assigned to a bin.
        pt = self.pseudotimes.dropna()

        if pt.empty:
            self.pt_binned = pd.DataFrame(columns=X.columns)
            return

        # Anchor the bins at the smallest pseudotime so lower and upper
        # edges stay aligned even when the trajectory does not start at 0.
        start = pt.min()
        offsets = (pt - start) / bin_size

        # At least one bin, so identical pseudotimes still get binned.
        n_bins = max(int(np.ceil(offsets.max())), 1)

        # Clip so a cell sitting exactly on the upper end of the range
        # joins the last bin instead of opening a new one.
        bin_idx = np.minimum(np.floor(offsets).astype(int), n_bins - 1)
        midpoints = start + (bin_idx + 0.5) * bin_size

        # One vectorised aggregation keeps numeric dtypes; growing a
        # frame row by row would leave object-dtype columns.
        self.pt_binned = X.loc[pt.index].groupby(
                                midpoints.to_numpy()).mean()
        
        