In [16]:
import pandas as pd
import numpy as np
import scanpy
import scipy
from sklearn.decomposition import PCA
# from scipy.spatial.distance import cosine
# from sklearn.preprocessing import StandardScaler



In [42]:
# check arguments
# check that no columns or rows with all zeros
# ensure that factor variables are NOT integers or real
# what if no covariates
# subsetting to most variable genes

class condPCA(object):
    def __init__(self, count_matrix_path, metadata_path, object_columns, vars_to_regress=True, n_PCs=200, random_seed=9989999):
        """
        Parameters
        ----------
        count_matrix:
            Count matrix that must be QC'd

        metadata:
            metadata containing cell type labels named "celltype"

        object_columns:
            columns that will be one hot encoded/columns that are factors 

        vars_to_regress:
            list of variables to regress out

        """
        self.count_matrix = scanpy.read(count_matrix_path) # cells x genes, pd.read_csv(count_matrix_path, sep='\t', header=0, index_col=0)
        self.metadata = pd.read_csv(metadata_path, sep='\t', header=0, index_col=0)
        if vars_to_regress:
            self.vars_to_regress = self.metadata.columns
        else: # if vars_to_regress is a list, convert to pandas core Index object
            self.vars_to_regress = pd.Index(vars_to_regress)

        # one hot encode necessary metadata variables
        self.object_columns = object_columns # obtain columns that must be one hot encoded
        self.metadata[self.object_columns] = self.metadata[self.object_columns].astype(object) # convert these columns to objects

        self.random_seed = random_seed # set random seed
        self.n_PCs = n_PCs

        

    def Normalize(self):
        """ 
        Normalize and take log1p of count data
        """
        scanpy.pp.normalize_total(self.count_matrix, target_sum = 10000) # update scanpy object to normalize all rows, so every cell sums to 10k
        scanpy.pp.log1p(self.count_matrix) # log transform

    def Standardize(self):
        """ 
        Standardize count data AND metadata
        """
        # Standardize count data
        if scipy.sparse.issparse(self.count_matrix.X):
            self.count_matrix.X = self.count_matrix.X.todense()
        self.standardized_count_data = self._standardize(self.count_matrix.X)

        # Process metadata/covariates for standardization:
        self.metadata = self.metadata[self.vars_to_regress] # subset to only variables that you want to regress out
        # WARNING IN FOLLOWING LINE BECAUSE CONVERTING OBJECT THAT LOOKS NUMERIC TO BE ONE HOT ENCODED, this is batch
        self.metadata = pd.get_dummies(self.metadata, drop_first=True) # Convert factor covariates to dummy variables dropping one column 
        self.standardized_metadata = self._standardize(self.metadata)
    
    def _standardize(self, mat): # simple function performing standardization
        mean_vector = np.mean(mat, axis=0)
        std_vector = np.std(mat, axis=0)
        stand_mat = (mat - mean_vector) / std_vector # standardize by gene
        return stand_mat
    
    def _regress_covariates(self, standardized_metadata, standardized_count_data): # function regressing set of covariates
        standardized_metadata = np.c_[np.ones((standardized_metadata.shape[0], 1)), standardized_metadata] # append ones to standardized meta for intercept
        inv_cov = np.linalg.pinv(np.matmul(standardized_metadata.T, standardized_metadata) ) # compute inverse of np.matmul(A^T, A) where A is the standardized metadata or covariates
        betas = np.apply_along_axis(self._column_wise_regression, axis=0, arr=standardized_count_data, inv_cov_mat=inv_cov, standardized_metadata_mat=standardized_metadata) # compute betas per gene
        prediction = np.matmul(standardized_metadata, betas) # compute prediction
        residual = standardized_count_data - prediction # compute residual
        standardized_residual = self._standardize(residual)
        return standardized_residual

    def _column_wise_regression(self, column, inv_cov_mat, standardized_metadata_mat): # perform regression of metadata/covariates across all genes
        betas = inv_cov_mat @ standardized_metadata_mat.T @ column # compute betas for a given gene (dimension of covariates plus 1 for intercept)
        return betas
    
    def _fit_pca(self, mat): # fitting PCA
        pca = PCA(n_components=self.n_PCs, random_state=self.random_seed) # instantiate PCA with hyperparameters
        pca.fit(mat) # projections (of input data onto eigenvectors)
        gene_loadings = pca.components_ # retrieve eigenvectors
        cell_embeddings = pca.transform(mat)
        eigenvalues = pca.explained_variance_ 
        return(cell_embeddings, gene_loadings, eigenvalues)
    
    def _compute_BIC(self, mat): # compute BIC significant PCs
        return 0


    def CondPCA_fit(self):
        # regress out covariates (including celltype) and retrieve standardized residual
        std_resid = self._regress_covariates(standardized_metadata = self.standardized_metadata, standardized_count_data= self.standardized_count_data)
        
        # return standardized residual as a dataframe with gene and cell names:
        standardized_residual = pd.DataFrame(std_resid, index = list(self.count_matrix.obs_names), columns = list(self.count_matrix.var_names))

        # perform PCA on residualized matrix
        self.COND_cell_embeddings, self.COND_gene_loadings, self.COND_eigenvalues = self._fit_pca(standardized_residual)

        # compute BIC

    def StandardPCA_fit(self):
        # remove celltype from covariate space
        metadata_minus_celltype = self.standardized_metadata.drop(columns = self.standardized_metadata.filter(like="celltype", axis=1).columns )
        
        # regress out covariates (not including celltype) and retrieve standardized residual
        std_resid = self._regress_covariates(standardized_metadata = metadata_minus_celltype, standardized_count_data= self.standardized_count_data)

        # return standardized residual as a dataframe with gene and cell names:
        standardized_residual = pd.DataFrame(std_resid, index = list(self.count_matrix.obs_names), columns = list(self.count_matrix.var_names))

        # perform PCA on residualized matrix
        self.STANDARD_cell_embeddings, self.STANDARD_gene_loadings, self.STANDARD_eigenvalues = self._fit_pca(standardized_residual)
    
    def Iter_PCA_fit(self):
        return(0)
        

        


        
        

In [43]:
# instantiate class
test = condPCA(count_matrix_path="/Users/shayecarver/condPCA/final_method/test_matrix.txt", metadata_path="/Users/shayecarver/condPCA/final_method/test_metadata.txt", object_columns=['Batch', 'Sex','celltype'])
test.Normalize()
test.Standardize()
#test.CondPCA_fit()
test.StandardPCA_fit()

  self.metadata = pd.get_dummies(self.metadata, drop_first=True) # Convert factor covariates to dummy variables dropping one column


In [34]:
test.metadata

Unnamed: 0,Age,Batch_2,Batch_3,Sex_M,celltype_EX,celltype_INH,celltype_MG,celltype_ODC,celltype_OPC,celltype_PER.END
GTCATCCTCCACGGAC-18,90,0,1,1,0,0,0,1,0,0
ATTTACCAGTTTAGGA-18,90,0,1,1,0,0,0,0,0,0
AGTGACTAGCAGCCCT-13,90,0,1,1,0,0,0,0,0,0
CAACGGCAGATAGCAT-2,90,0,0,0,0,0,0,0,1,0
TGATTTCCACATGTTG-5,81,1,0,1,0,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...
TGACTCCCATGACAGG-11,90,1,0,1,0,0,0,0,0,1
GCAGCTGGTTCCACGG-12,79,0,1,0,0,0,1,0,0,0
CGTAATGTCGTCAAAC-4,86,1,0,1,0,0,0,1,0,0
TCGTCCAGTACCTAGT-1,90,0,1,0,0,0,0,1,0,0


In [47]:
test.STANDARD_eigenvalues

array([219.8883073 ,  91.1088594 ,  66.37504267,  41.77153183,
        31.37351822,  16.86261716,  16.2158281 ,  14.10836702,
         9.14790565,   8.3455519 ,   7.91771426,   7.64267714,
         7.34903443,   6.15274656,   5.86787993,   5.42548339,
         5.23741411,   4.99350152,   4.81599546,   4.68763116,
         4.41007968,   4.31845392,   4.18999522,   3.94540678,
         3.73653096,   3.50312271,   3.45035015,   3.4410275 ,
         3.30020412,   3.23332113,   3.17420468,   3.12376112,
         2.8784412 ,   2.8440745 ,   2.82569171,   2.78582364,
         2.66910726,   2.59472415,   2.57498315,   2.56014301,
         2.52140779,   2.49207228,   2.48632381,   2.43974997,
         2.4023973 ,   2.38938701,   2.37703771,   2.34056705,
         2.31960555,   2.30826656,   2.28001371,   2.26116096,
         2.2454055 ,   2.22686031,   2.20996552,   2.19304051,
         2.18344586,   2.16415417,   2.15136226,   2.14316954,
         2.12137783,   2.11394327,   2.09212521,   2.07