In [1]:
import pandas as pd
import numpy as np
import scanpy
import scipy
from sklearn.decomposition import PCA
from functools import partial
# from scipy.spatial.distance import cosine
# from sklearn.preprocessing import StandardScaler



In [2]:
# TODO:
# - validate the constructor arguments
# - check that the count matrix has no rows or columns that are all zeros
# - ensure that factor variables are NOT stored as integers or reals
# - handle the case where there are no covariates
# - when subsetting to the most variable genes, IterPCA and standard/conditional PCA will use different gene sets
# - output the mean-variance relationship
# - MUST ADD VARGENES SUBSET HERE!

class condPCA(object):
    """Conditional, standard and per-cell-type ("iterative") PCA on count data.

    Intended call order: construct, ``Normalize()``, ``Standardize()``, then
    any of ``CondPCA_fit()``, ``StandardPCA_fit()`` or ``Iter_PCA_fit()``.

    NOTE(review): nothing here checks that the metadata rows are in the same
    order as the cells of the count matrix -- the caller must guarantee it.
    """

    def __init__(self, count_matrix_path, metadata_path, object_columns, vars_to_regress=True, n_PCs=200, random_seed=9989999):
        """
        Parameters
        ----------
        count_matrix_path:
            Path handed to ``scanpy.read``; the count matrix (cells x genes)
            must already be QC'd.

        metadata_path:
            Tab-separated file (header row, first column used as index)
            containing cell type labels in a column named "celltype".

        object_columns:
            Columns that are factors and will be one hot encoded.

        vars_to_regress:
            ``True`` (default) regresses out every metadata column; otherwise
            a list-like of the metadata column names to regress out.

        n_PCs:
            Number of principal components requested from each PCA.

        random_seed:
            Seed forwarded to the PCA ``random_state``.
        """
        self.count_matrix = scanpy.read(count_matrix_path)  # cells x genes

        self.metadata = pd.read_csv(metadata_path, sep='\t', header=0, index_col=0)

        self.vars_to_regress = self._resolve_vars_to_regress(vars_to_regress, self.metadata.columns)

        # factor columns are cast to object so pd.get_dummies one hot encodes
        # them even when their values look numeric (e.g. batch ids)
        self.object_columns = object_columns
        self.metadata[self.object_columns] = self.metadata[self.object_columns].astype(object)

        self.random_seed = random_seed

        self.n_PCs = n_PCs

    @staticmethod
    def _resolve_vars_to_regress(vars_to_regress, metadata_columns):
        """Return the covariate names to regress out as a pandas Index.

        Only the literal ``True`` selects every metadata column. A plain
        truthiness test would also match any non-empty list and silently
        discard the caller's selection. Any other value is passed to
        ``pd.Index``.
        """
        if vars_to_regress is True:
            return metadata_columns
        return pd.Index(vars_to_regress)

    def Normalize(self):
        """
        Normalize and take log1p of count data (modifies ``count_matrix`` in place)
        """
        # every cell (row) is scaled to sum to 10k
        scanpy.pp.normalize_total(self.count_matrix, target_sum=10000)

        # log transform
        scanpy.pp.log1p(self.count_matrix)

    def Standardize(self):
        """
        Standardize count data AND metadata

        Sets ``standardized_count_data``, ``IterPCA_metadata`` (full one hot
        encoding) and ``standardized_metadata``; ``metadata`` is overwritten
        with the dummy-coded covariates, so this is not meant to be called twice.
        """
        # densify as a plain ndarray (``todense`` would give a np.matrix)
        if scipy.sparse.issparse(self.count_matrix.X):
            self.count_matrix.X = self.count_matrix.X.toarray()

        self.standardized_count_data = self._standardize(np.asarray(self.count_matrix.X))

        # subset to only variables that you want to regress out
        self.metadata = self.metadata[self.vars_to_regress]

        # full one hot encoding (no dropped level): used to build per-cell-type masks.
        # pandas may warn here because numeric-looking object columns (batch) are encoded.
        self.IterPCA_metadata = pd.get_dummies(self.metadata, drop_first=False)

        # dummy coding that drops one level per factor: used as regression covariates
        self.metadata = pd.get_dummies(self.metadata, drop_first=True)

        self.standardized_metadata = self._standardize(self.metadata)

    def _standardize(self, mat):
        """Center every column of ``mat`` and scale it to unit (population) variance.

        Accepts a 2-D array or DataFrame and returns the same kind of object.
        Constant columns come back as exactly zero instead of NaN/inf.
        """
        mean_vector = np.mean(mat, axis=0)

        scale = np.array(np.std(mat, axis=0)).ravel()

        # max == min detects constant columns exactly, even when rounding in
        # the mean makes the computed standard deviation slightly non-zero
        is_constant = np.asarray(np.max(mat, axis=0) == np.min(mat, axis=0)).ravel()

        # dividing by inf maps every entry of a constant column to 0
        scale[is_constant] = np.inf

        # standardize by column (gene / covariate)
        return (mat - mean_vector) / scale

    def _regress_covariates(self, standardized_metadata, standardized_count_data):
        """Regress the covariates (plus an intercept) out of every gene.

        Returns the re-standardized residual matrix (cells x genes) as an ndarray.
        """
        counts = np.asarray(standardized_count_data)

        # design matrix: a column of ones for the intercept followed by the covariates
        covariates = np.asarray(standardized_metadata, dtype=float)
        design = np.c_[np.ones((covariates.shape[0], 1)), covariates]

        # pseudo-inverse of A^T A, so collinear or all-zero covariates are tolerated
        inv_cov = np.linalg.pinv(np.matmul(design.T, design))

        # betas for all genes at once: (covariates + 1) x genes
        betas = inv_cov @ (design.T @ counts)

        prediction = np.matmul(design, betas)

        residual = counts - prediction

        return self._standardize(residual)

    def _column_wise_regression(self, column, inv_cov_mat, standardized_metadata_mat):
        """Return the betas (covariates plus intercept) for a single gene ``column``."""
        betas = inv_cov_mat @ standardized_metadata_mat.T @ column

        return betas

    def _fit_pca(self, mat):
        """Fit PCA on ``mat``; return ``(cell_embeddings, gene_loadings, eigenvalues)``."""
        # never request more components than the data can support
        # (matters for small per-cell-type subsets)
        n_components = min(self.n_PCs, min(mat.shape))

        pca = PCA(n_components=n_components, random_state=self.random_seed)

        pca.fit(mat)

        # eigenvectors
        gene_loadings = pca.components_

        # projections of the input data onto the eigenvectors
        cell_embeddings = pca.transform(mat)

        eigenvalues = pca.explained_variance_

        return (cell_embeddings, gene_loadings, eigenvalues)

    def _compute_BIC(self, mat):
        """Placeholder for computing BIC significant PCs; currently always returns 0."""
        return 0

    def _fit_model(self, standardized_metadata, standardized_count_data, iterPCA=False, iterPCA_genenames=False, iterPCA_cellnames=False):
        """Regress out the covariates, then run PCA on the standardized residuals.

        With ``iterPCA=True`` the residuals are labelled with the supplied
        subset names; otherwise with all cell and gene names of the count matrix.
        """
        standardized_residual = self._regress_covariates(standardized_metadata=standardized_metadata, standardized_count_data=standardized_count_data)

        if iterPCA:
            cellnames, genenames = iterPCA_cellnames, iterPCA_genenames
        else:
            #  MUST ADD VARGENES SUBSET HERE !
            cellnames, genenames = self.count_matrix.obs_names, self.count_matrix.var_names

        # residuals as a dataframe carrying cell and gene names
        standardized_residual = pd.DataFrame(standardized_residual, index=list(cellnames), columns=list(genenames))

        # perform PCA on residualized matrix
        return self._fit_pca(standardized_residual)

    def _metadata_without_celltype(self):
        """Return ``standardized_metadata`` without the columns whose name contains "celltype"."""
        celltype_columns = self.standardized_metadata.filter(like="celltype", axis=1).columns
        return self.standardized_metadata.drop(columns=celltype_columns)

    def _subset_to_celltype(self, metadata, CT_column):
        """Subset counts and ``metadata`` to the cells flagged in ``CT_indices_df[CT_column]``.

        Both subsets are re-standardized because they have just been subset.
        Returns ``(counts, metadata, cellnames)``.
        """
        mask = self.CT_indices_df[CT_column].to_numpy()

        count_data_subset_to_CT = self._standardize(self.standardized_count_data[mask])
        metadata_subset_to_CT = self._standardize(metadata[mask])

        # names / barcodes of the cells belonging to the cell type
        cellnames = self.count_matrix.obs_names[mask]

        return count_data_subset_to_CT, metadata_subset_to_CT, cellnames

    def _mapping_IterPCA_subset_dataframes_to_PCA(self, metadata, CT_column):
        """Fit the regression + PCA model on the cells of one cell type.

        ``metadata`` holds the covariates to regress out (rows aligned with
        the cells); ``CT_column`` names a boolean column of ``CT_indices_df``.
        """
        count_data_subset_to_CT, metadata_subset_to_CT, cellnames = self._subset_to_celltype(metadata, CT_column)

        #  MUST ADD VARGENES SUBSET HERE !
        genenames = self.count_matrix.var_names

        return self._fit_model(standardized_metadata=metadata_subset_to_CT, standardized_count_data=count_data_subset_to_CT, iterPCA=True, iterPCA_genenames=genenames, iterPCA_cellnames=cellnames)

    def CondPCA_fit(self):
        """Regress out all covariates (including cell type) and fit PCA; sets ``COND_*``."""
        self.COND_cell_embeddings, self.COND_gene_loadings, self.COND_eigenvalues = self._fit_model(standardized_metadata=self.standardized_metadata, standardized_count_data=self.standardized_count_data)

        # compute BIC

    def StandardPCA_fit(self):
        """Regress out all covariates except cell type and fit PCA; sets ``STANDARD_*``."""
        metadata_minus_celltype = self._metadata_without_celltype()

        self.STANDARD_cell_embeddings, self.STANDARD_gene_loadings, self.STANDARD_eigenvalues = self._fit_model(standardized_metadata=metadata_minus_celltype, standardized_count_data=self.standardized_count_data)

    def Iter_PCA_fit(self):
        """Fit the non-celltype regression + PCA separately within every cell type.

        Sets ``result`` (list of PCA outputs in cell type column order) and
        ``result_by_celltype`` (the same outputs keyed by cell type column name).
        """
        # covariates to regress out within each cell type (celltype itself removed)
        metadata_minus_celltype = self._metadata_without_celltype()

        # dataframe with one boolean membership column per cell type
        self.CT_indices_df = self.IterPCA_metadata.filter(like="celltype", axis=1).astype(bool)

        celltype_colnames = self.CT_indices_df.columns

        # the covariates -- not the count matrix -- are what gets subset and regressed out
        self.result = [
            self._mapping_IterPCA_subset_dataframes_to_PCA(metadata_minus_celltype, CT_column)
            for CT_column in celltype_colnames
        ]

        self.result_by_celltype = dict(zip(celltype_colnames, self.result))
       

        



        

        


        
        

In [3]:
# Build the model from the test count matrix and metadata, then run the
# preprocessing steps followed by the per-cell-type (iterative) PCA.
test = condPCA(
    count_matrix_path="/Users/shayecarver/condPCA/final_method/test_matrix.txt",
    metadata_path="/Users/shayecarver/condPCA/final_method/test_metadata.txt",
    object_columns=['Batch', 'Sex', 'celltype'],
)
test.Normalize()
test.Standardize()
# The other two fits are currently disabled:
#test.CondPCA_fit()
#test.StandardPCA_fit()
test.Iter_PCA_fit()

  self.IterPCA_metadata = pd.get_dummies(self.metadata, drop_first=False)
  self.metadata = pd.get_dummies(self.metadata, drop_first=True)


(416,)
Index(['ATTTACCAGTTTAGGA-18', 'AGTGACTAGCAGCCCT-13', 'GGGACAACACGACTAT-1',
       'TTATTGCCAGGACTTT-5', 'GCATGATAGGGAGATA-8', 'ACTTAGGTCGCCTTTG-6',
       'CGGAACCTCAACCCGG-3', 'CTGAGGCGTCCATAGT-12', 'GAGACCCGTTTCGCTC-12',
       'TATTTCGGTTCCATTT-14',
       ...
       'GTGTTAGAGACCATAA-7', 'GGGCCATAGAGCCGAT-8', 'AATAGAGTCCTCTGCA-2',
       'AACTTCTGTAATTAGG-10', 'TTACGCCCAGCGAGTA-3', 'CAAGGGATCAGACCTA-7',
       'TAGACTGTCATGCATG-9', 'AGCCACGTCCACAAGT-10', 'ACGGGTCAGCTTCATG-9',
       'TGAATCGGTCAAAGTA-3'],
      dtype='object', length=416)
(532,)
Index(['TACTGCCGTTTACGAC-5', 'GTGTTAGGTTCAGCGC-6', 'GCGGAAATCACTCTTA-1',
       'CCTCTCCTCTGCGTCT-9', 'TTGCATTTCATGCGGC-7', 'ACGATGTAGTTGAAGT-6',
       'CTTCCTTGTCTTACAG-9', 'CTGCGAGCAGTGTGCC-9', 'CAATTTCGTATCGCTA-6',
       'GCAACCGCATAGTCGT-16',
       ...
       'CACTGAACAGATGCGA-4', 'CACGGGTAGCTAGATA-9', 'CCTCATGCATGTGGTT-2',
       'AAGTTCGAGACGCAGT-1', 'TGTTTGTTCAAGAGTA-1', 'TTCTAACAGTTATGGA-6',
       'TCATGGAAGGCCTAGA-7', 'T

In [118]:
# Total number of cells across all per-cell-type results; this should equal
# the number of cells in the full dataset.
# The accumulator is not called `sum`, so the builtin stays usable in later cells.
total_cells = 0
for ct_result in test.result:
    # Each entry is the tuple (cell_embeddings, gene_loadings, eigenvalues)
    # returned by the PCA fit; a bare matrix/dataframe entry is also accepted.
    embeddings = ct_result[0] if isinstance(ct_result, tuple) else ct_result
    total_cells += embeddings.shape[0]

total_cells

5000

In [125]:
# Display the standardized count matrix that Standardize() stored on the object.
test.standardized_count_data

array([[-0.22454718, -0.19931257, -0.15417232, ..., -0.15137516,
        -0.40648767, -0.6635374 ],
       [-0.22454718, -0.19931257, -0.15417232, ..., -0.15137516,
        -0.40648767,  1.6786541 ],
       [-0.22454718, -0.19931257, -0.15417232, ..., -0.15137516,
        -0.40648767,  1.893048  ],
       ...,
       [-0.22454718, -0.19931257, -0.15417232, ..., -0.15137516,
        -0.40648767,  2.2684371 ],
       [-0.22454718, -0.19931257, -0.15417232, ..., -0.15137516,
        -0.40648767, -0.6635374 ],
       [-0.22454718, -0.19931257, -0.15417232, ..., -0.15137516,
         1.3463936 ,  1.0185403 ]], dtype=float32)

In [7]:
# Display the entry at position 3 of the per-cell-type results built by Iter_PCA_fit()
# (entries follow the order of the cell type columns).
test.result[3]

Unnamed: 0,AC109466.1,AC008415.1,AC016687.2,LINC02055,RELN,FGF13,AJ009632.2,HTR2C,TSHZ2,ADAMTSL1,...,L1TD1,AL160262.1,PROX1-AS1,AL162718.1,BX664727.3,RIPOR3,MYO3A,AC064875.1,STAMBPL1,GABRB1
ACTCCCAGTTATTCTC-12,-0.428016,2.071507,-0.811296,-1.208652,0.450364,-0.733668,-0.556622,-0.033418,-0.541458,-0.222393,...,0.267474,0.294073,-1.266965,0.289807,0.493817,0.203516,2.554037,0.276916,1.706070,-0.017939
ATGAGGGAGTAACCGG-14,0.927202,0.851534,-1.359764,1.623090,0.332552,-0.260247,-0.038106,1.716419,0.199485,-0.681727,...,-0.011564,0.033929,0.316217,-0.058185,0.082864,-0.056669,-0.545206,-0.101268,-0.090013,0.312183
TCCCACACAAACCGGA-8,1.168487,-1.318341,-0.519177,0.593366,0.435637,0.327034,2.087128,-0.881824,-0.185236,0.110624,...,0.313980,0.424144,1.353934,0.326437,1.156645,0.367843,-0.535740,0.423988,-0.974489,1.599660
GTTCGCTGTCACGACC-4,0.722110,0.534900,0.613528,0.289965,1.289771,1.237919,0.137576,1.069510,1.111414,-1.772646,...,1.492140,1.448458,1.659927,1.480304,0.162403,1.504440,-1.327116,1.600562,-2.200817,1.170501
GATCCCTGTGGAAATT-8,-1.180021,0.376583,0.059098,0.731275,0.273647,-0.296203,-0.133622,1.249796,1.823860,0.891493,...,0.949566,0.895654,-0.335681,0.784321,-0.566707,0.833437,0.297288,0.823183,0.978955,-0.183000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGCGTCTCCGTGTCT-7,-0.995036,2.611647,-0.238982,0.018743,0.317826,-0.991353,-3.292476,0.921039,0.783690,1.546044,...,0.406993,0.440403,-1.479830,0.509591,-1.070456,0.463701,0.645645,0.487019,1.141742,-0.909269
GGTAACTTCCGGTAGC-6,-1.718891,-1.150711,0.017367,-2.164825,0.111656,-0.404071,-0.283718,1.886100,1.310899,1.695327,...,0.701533,0.733064,-0.282465,0.729375,-0.354602,0.669110,-0.895456,0.697121,-0.806275,0.411220
GGAAGTGGTGTGTCCG-15,-1.075464,0.655966,0.530065,1.117422,0.700713,0.890345,-1.375331,1.292216,0.384720,-0.210910,...,1.089085,1.041985,0.835075,1.059051,1.156645,1.052540,-0.005631,1.033285,-0.477988,-0.133482
CCCAACTAGAGCCGAT-14,0.018361,-0.973768,-1.496881,0.349726,-0.727751,1.154022,-0.051751,-2.461979,-1.695619,-1.175512,...,-0.600644,-0.616429,0.196481,-0.607646,-1.149995,-0.590733,-0.717491,-0.689555,-0.594652,-0.282037
