In [5]:
import sys
import yaml
import torch
import logging
from pathlib import Path

# Add the path to the directory containing the omnicell package
# Assuming the omnicell package is in the parent directory of your notebook
sys.path.append('..')  # Adjust this path as needed

import yaml
import torch
import logging
from pathlib import Path
from omnicell.config.config import Config, ETLConfig, ModelConfig, DatasplitConfig, EvalConfig, EmbeddingConfig
from omnicell.data.loader import DataLoader
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT
from train import get_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure paths
MODEL_CONFIG = ModelConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/models/linear_mean_model.yaml")
ETL_CONFIG = ETLConfig(name = "no_preprocessing", log1p = False, drop_unmatched_perts = True)
EMBEDDING_CONFIG = EmbeddingConfig(pert_embedding='GenePT')

SPLIT_CONFIG = DatasplitConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/split_config.yaml")
EVAL_CONFIG = EvalConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/eval_config.yaml")  # Set this if you want to run evaluations

# Load configurations
config = Config(model_config=MODEL_CONFIG,
                 etl_config=ETL_CONFIG, 
                 datasplit_config=SPLIT_CONFIG, 
                 eval_config=EVAL_CONFIG)




#Alternatively you can initialize the config objects manually as follows:
# etl_config = ETLConfig(name = XXX, log1p = False, drop_unmatched_perts = False, ...)
# model_config = ...
# embedding_config = ...
# datasplit_config = ...
# eval_config = ...
# config = Config(etl_config, model_config, datasplit_config, eval_config)


config.etl_config.pert_embedding = 'bioBERT'
config.etl_config.drop_unmatched_perts = True
# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize data loader and load training data
loader = DataLoader(config)
adata, pert_rep_map = loader.get_training_data()

# Get dimensions and perturbation IDs
input_dim = adata.obsm['embedding'].shape[1]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pert_ids = adata.obs[PERT_KEY].unique()
gene_emb_dim = adata.varm[GENE_EMBEDDING_KEY].shape[1] if GENE_EMBEDDING_KEY in adata.varm else None

print(f"Data loaded:")
print(f"- Number of cells: {adata.shape[0]}")
print(f"- Input dimension: {input_dim}")
print(f"- Number of perturbations: {len(pert_ids)}")



2025-01-31 10:30:10,248 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-01-31 10:30:10,250 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_raw.yaml
2025-01-31 10:30:10,252 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/adamson_INCOMPLETE.yaml
2025-01-31 10:30:10,253 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_HVG.yaml
2025-01-31 10:30:10,255 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/kang.yaml
2025-01-31 10:30:10,257 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/essential_gene_knockouts_raw.yaml
2025-01-31 10:30:10,258 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNG_raw_INCOMPLET

Using device: cuda


2025-01-31 10:30:14,807 - INFO - Loaded unpreprocessed data, # of data points: 310385, # of genes: 8563.
2025-01-31 10:30:14,808 - INFO - Preprocessing training data
2025-01-31 10:30:14,810 - INFO - Using identity features for perturbations
2025-01-31 10:30:14,933 - INFO - Removing observations with perturbations not in the dataset as a column
2025-01-31 10:30:15,119 - INFO - Removed 189 perturbations that were not in the dataset columns and 0 perturbations that did not have an embedding for a total of 189 perturbations removed out of an initial 2058 perturbations
  adata.obsm["embedding"] = adata.X.toarray().astype('float32')
2025-01-31 10:30:52,099 - INFO - Doing OOD split


Data loaded:
- Number of cells: 279630
- Input dimension: 8563
- Number of perturbations: 1850


In [6]:

model = get_model(config.model_config.name, config.model_config.parameters, loader, pert_rep_map, input_dim, device, pert_ids, gene_emb_dim)

2025-01-31 10:30:59,352 - INFO - Mean model selected


In [7]:
model.train(adata, None)

  pert_means = df.groupby('perturbation').mean()


In [28]:
from os import listdir
from scipy.sparse import issparse
import anndata
import scanpy as sc
import numpy as np
import pandas as pd

from scipy.stats import pearsonr

import logging

logger = logging.getLogger(__name__)

def r2_mse_filename(pert, cell):
    return f'r2_and_mse_{pert}_{cell}.json'

def c_r_filename(pert, cell):
    return f'c_r_results_{pert}_{cell}.json'

def DEGs_overlap_filename(pert, cell):   
    return f'DEGs_overlaps_{pert}_{cell}.json'


def get_DEG_with_direction(gene, score):
    if score > 0:
        return(f'{gene}+')
    else:
        return(f'{gene}-')
        
def to_dense(X):
    if issparse(X):
        return X.toarray()
    else:
        return np.asarray(X)

def efficient_correlation_pearsonr(U, V):
    """
    Compute squared Pearson correlation between correlation matrices 
    for matrices of different ranks
    U: (n_samples x k1) matrix
    V: (n_samples x k2) matrix
    """
    # Center the data
    U_centered = U - U.mean(axis=1, keepdims=True)
    V_centered = V - V.mean(axis=1, keepdims=True)
    
    # Normalize rows
    U_norms = np.sqrt(np.sum(U_centered * U_centered, axis=1, keepdims=True))
    V_norms = np.sqrt(np.sum(V_centered * V_centered, axis=1, keepdims=True))
    
    U_normalized = U_centered / U_norms
    V_normalized = V_centered / V_norms
    
    # Work with smaller matrices
    UTU = U_normalized.T @ U_normalized  # k1×k1 matrix
    VTV = V_normalized.T @ V_normalized  # k2×k2 matrix
    UV = U_normalized.T @ V_normalized   # k1×k2 matrix
    
    # Compute means efficiently
    mean_UTU = np.sum(UTU) / (U.shape[0] ** 2)
    mean_VTV = np.sum(VTV) / (V.shape[0] ** 2)
    
    # Use trace tricks with different sized matrices
    # For matrices of different sizes, tr(UU^T VV^T) = tr((U^TV)(V^TU)) = sum(UV * VU)
    # print(UV.shape, V_normalized.T.shape, U_normalized.shape)
    trace_UUTVVT = np.sum(UV.T * (V_normalized.T @ U_normalized))
    trace_UUTUUT = np.sum(UTU * UTU)
    trace_VVTVVT = np.sum(VTV * VTV)
    
    n_squared = U.shape[0] ** 2
    
    # Compute correlation
    numerator = trace_UUTVVT - n_squared * mean_UTU * mean_VTV
    denominator = np.sqrt((trace_UUTUUT - n_squared * mean_UTU ** 2) * 
                         (trace_VVTVVT - n_squared * mean_VTV ** 2))
    
    r = numerator / denominator
    return r

def efficient_covariance_pearsonr(U, V):
    """
    Compute squared Pearson correlation between covariance matrices 
    for matrices of different ranks
    U: (n_samples x k1) matrix
    V: (n_samples x k2) matrix
    """
    # Center the data
    U_centered = U - U.mean(axis=1, keepdims=True)
    V_centered = V - V.mean(axis=1, keepdims=True)
    
    # For covariance, we don't normalize by row norms
    # Compute smaller matrices directly
    UTU = U_centered.T @ U_centered  # k1×k1 matrix
    VTV = V_centered.T @ V_centered  # k2×k2 matrix
    UV = U_centered.T @ V_centered   # k1×k2 matrix
    
    # Compute means efficiently
    mean_UTU = np.sum(UTU) / (U.shape[0] ** 2)
    mean_VTV = np.sum(VTV) / (V.shape[0] ** 2)
    
    # Use trace tricks with different sized matrices
    trace_UUTVVT = np.sum(UV.T * (V_centered.T @ U_centered))
    trace_UUTUUT = np.sum(UTU * UTU)
    trace_VVTVVT = np.sum(VTV * VTV)
    
    n_squared = U.shape[0] ** 2
    
    # Compute correlation
    numerator = trace_UUTVVT - n_squared * mean_UTU * mean_VTV
    denominator = np.sqrt((trace_UUTUUT - n_squared * mean_UTU ** 2) * 
                         (trace_VVTVVT - n_squared * mean_VTV ** 2))
    
    r = numerator / denominator
    return r

def get_eval(ctrl_adata, true_adata, pred_adata, DEGs, DEG_vals, pval_threshold, lfc_threshold):
        
    results_dict =  {}
    
    logger.debug(f"Computing R, R2, and MSE metrics")
    ctrl_X = to_dense(ctrl_adata.X)
    true_X = to_dense(true_adata.X)
    pred_X = to_dense(pred_adata.X)

    ctrl_mean = ctrl_X.mean(axis = 0)

    true_mean = true_X.mean(axis = 0)
    true_var = true_X.var(axis = 0)
    
    pred_mean = pred_X.mean(axis = 0)
    pred_var = pred_X.var(axis = 0)
    print("mean and var")
    
    # true_corr_mtx = np.corrcoef(true_X, rowvar=False).flatten()
    # true_cov_mtx = np.cov(true_X, rowvar=False).flatten()
        
    # pred_corr_mtx = np.corrcoef(pred_X, rowvar=False).flatten()
    # pred_cov_mtx = np.cov(pred_X, rowvar=False).flatten()

    true_sub_diff = true_mean - ctrl_mean
    pred_sub_diff = pred_mean - ctrl_mean

    true_diff = np.expm1(true_mean) - np.expm1(ctrl_mean)
    pred_diff = np.expm1(pred_mean) - np.expm1(ctrl_mean)

    results_dict['all_genes_mean_sub_diff_R'] = pearsonr(true_sub_diff, pred_sub_diff)[0]
    results_dict['all_genes_mean_sub_diff_R2'] = pearsonr(true_sub_diff, pred_sub_diff)[0]**2
    results_dict['all_genes_mean_sub_diff_MSE'] = (np.square(true_sub_diff - pred_sub_diff)).mean(axis=0)
    print("mean sub diff")

    results_dict['all_genes_mean_fold_diff_R'] = pearsonr(true_diff, pred_diff)[0]
    results_dict['all_genes_mean_fold_diff_R2'] = pearsonr(true_diff, pred_diff)[0]**2
    results_dict['all_genes_mean_fold_diff_MSE'] = (np.square(true_diff - pred_diff)).mean(axis=0)

    results_dict['all_genes_mean_R'] = pearsonr(true_mean, pred_mean)[0]
    results_dict['all_genes_mean_R2'] = pearsonr(true_mean, pred_mean)[0]**2
    results_dict['all_genes_mean_MSE'] = (np.square(true_mean - pred_mean)).mean(axis=0)

    results_dict['all_genes_var_R'] = pearsonr(true_var, pred_var)[0]
    results_dict['all_genes_var_R2'] = pearsonr(true_var, pred_var)[0]**2
    results_dict['all_genes_var_MSE'] = (np.square(true_var - pred_var)).mean(axis=0)

    corr_r = efficient_correlation_pearsonr(true_X.T, pred_X.T)
    cov_r = efficient_covariance_pearsonr(true_X.T, pred_X.T)

    results_dict['all_genes_corr_mtx_R'] = corr_r # pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]
    results_dict['all_genes_corr_mtx_R2'] = corr_r**2 # pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]**2
    # results_dict['all_genes_corr_mtx_MSE'] = (np.square(true_corr_mtx.flatten() - pred_corr_mtx.flatten())).mean(axis=0)

    results_dict['all_genes_cov_mtx_R'] = cov_r
    results_dict['all_genes_cov_mtx_R2'] = cov_r**2
    # results_dict['all_genes_cov_mtx_MSE'] = (np.square(true_cov_mtx.flatten() - pred_cov_mtx.flatten())).mean(axis=0)

    if lfc_threshold:   
        significant_DEGs = DEGs[(DEGs['pvals_adj'] < pval_threshold) & (abs(DEGs) > lfc_threshold)]
    else:
        significant_DEGs = DEGs[DEGs['pvals_adj'] < pval_threshold]
    num_DEGs = len(significant_DEGs)
    DEG_vals.insert(0, num_DEGs)


    logger.debug(f"Significant DEGs {significant_DEGs}")
    
    for val in DEG_vals:

        logger.debug(f"Computing R, R2, and MSE metrics for top {val} DEGs")

        #If val == 1 we can't
        if ((val > num_DEGs) or (val == 0) or (val == 1)):
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_mean_MSE'] = None

            results_dict[f'Top_{val}_DEGs_var_R'] = None
            results_dict[f'Top_{val}_DEGs_var_R2'] = None
            results_dict[f'Top_{val}_DEGs_var_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_corr_mtx_R'] = None
            results_dict[f'Top_{val}_DEGs_corr_mtx_R2'] = None
            results_dict[f'Top_{val}_DEGs_corr_mtx_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_cov_mtx_R'] = None
            results_dict[f'Top_{val}_DEGs_cov_mtx_R2'] = None
            results_dict[f'Top_{val}_DEGs_cov_mtx_MSE'] = None
        
        else:
            top_DEGs = significant_DEGs[0:val].index.map(int)

            logger.debug(f"Top DEGs: {top_DEGs}")


            #Reshape --> If there is a single gene, the shape is (1,) and we need to reshape it to (1,1)

            ctrl_mean = to_dense(ctrl_X[:,top_DEGs]).mean(axis = 0)
            true_mean = to_dense(true_X[:,top_DEGs]).mean(axis = 0)

            logger.debug(f"Shape ctrl_adata with top DEGs: {ctrl_adata[:,top_DEGs].X.shape}, shape true_adata with top DEGs: {true_adata[:,top_DEGs].X.shape}")


            true_var = to_dense(true_X[:,top_DEGs]).var(axis = 0)
            true_corr_mtx = np.corrcoef(to_dense(true_X[:,top_DEGs]), rowvar=False).flatten()
            true_cov_mtx = np.cov(to_dense(true_X[:,top_DEGs]), rowvar=False).flatten()

            pred_mean = to_dense(pred_X[:,top_DEGs]).mean(axis = 0)
            logger.debug(f"Shape of true_mean shape: {true_mean.shape}, ctrl_mean shape: {ctrl_mean.shape}, pred_mean shape: {pred_mean.shape}")

            pred_var = to_dense(pred_X[:,top_DEGs]).var(axis = 0)
            pred_corr_mtx = np.corrcoef(to_dense(pred_X[:,top_DEGs]), rowvar=False).flatten()
            pred_cov_mtx = np.cov(to_dense(pred_X[:,top_DEGs]), rowvar=False).flatten()

            logger.debug(f"Shape of true_var shape: {true_var.shape}, pred_var shape: {pred_var.shape}")

            true_sub_diff = true_mean - ctrl_mean
            pred_sub_diff = pred_mean - ctrl_mean
        
            # inverse log1p to get sub change
            true_diff = np.expm1(true_mean) - np.expm1(ctrl_mean)
            pred_diff = np.expm1(pred_mean) - np.expm1(ctrl_mean)

            results_dict[f'Top_{val}_DEGs_sub_diff_R'] = pearsonr(true_sub_diff, pred_sub_diff)[0]
            results_dict[f'Top_{val}_DEGs_sub_diff_R2'] = pearsonr(true_sub_diff, pred_sub_diff)[0]**2
            results_dict[f'Top_{val}_DEGs_sub_diff_MSE'] = (np.square(true_sub_diff - pred_sub_diff)).mean(axis=0)
        
            results_dict[f'Top_{val}_DEGs_fold_diff_R'] = pearsonr(true_diff, pred_diff)[0]
            results_dict[f'Top_{val}_DEGs_fold_diff_R2'] = pearsonr(true_diff, pred_diff)[0]**2
            results_dict[f'Top_{val}_DEGs_fold_diff_MSE'] = (np.square(true_diff - pred_diff)).mean(axis=0)
    
            results_dict[f'Top_{val}_DEGs_mean_R'] = pearsonr(true_mean, pred_mean)[0]
            results_dict[f'Top_{val}_DEGs_mean_R2'] = pearsonr(true_mean, pred_mean)[0]**2
            results_dict[f'Top_{val}_DEGs_mean_MSE'] = (np.square(true_mean - pred_mean)).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_var_R'] = pearsonr(true_var, pred_var)[0]
            results_dict[f'Top_{val}_DEGs_var_R2'] = pearsonr(true_var, pred_var)[0]**2
            results_dict[f'Top_{val}_DEGs_var_MSE'] = (np.square(true_var - pred_var)).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_corr_mtx_R'] = pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]
            results_dict[f'Top_{val}_DEGs_corr_mtx_R2'] = pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]**2
            results_dict[f'Top_{val}_DEGs_corr_mtx_MSE'] = (np.square(true_corr_mtx.flatten() - pred_corr_mtx.flatten())).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_cov_mtx_R'] = pearsonr(true_cov_mtx.flatten(), pred_cov_mtx.flatten())[0]
            results_dict[f'Top_{val}_DEGs_cov_mtx_R2'] = pearsonr(true_cov_mtx.flatten(), pred_cov_mtx.flatten())[0]**2
            results_dict[f'Top_{val}_DEGs_cov_mtx_MSE'] = (np.square(true_cov_mtx.flatten() - pred_cov_mtx.flatten())).mean(axis=0)

    return results_dict

In [29]:
r2_and_mse = get_eval(control, true_pert, pred_pert, true_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)


mean and var
mean sub diff


In [30]:
from os import listdir
from scipy.sparse import issparse
import anndata
import scanpy as sc
import numpy as np
import pandas as pd

from scipy.stats import pearsonr

import logging

logger = logging.getLogger(__name__)

def r2_mse_filename(pert, cell):
    return f'r2_and_mse_{pert}_{cell}.json'

def c_r_filename(pert, cell):
    return f'c_r_results_{pert}_{cell}.json'

def DEGs_overlap_filename(pert, cell):   
    return f'DEGs_overlaps_{pert}_{cell}.json'


def get_DEG_with_direction(gene, score):
    if score > 0:
        return(f'{gene}+')
    else:
        return(f'{gene}-')
        
def to_dense(X):
    if issparse(X):
        return X.toarray()
    else:
        return np.asarray(X)

def get_DEGs(control_adata, target_adata):
    temp_concat = anndata.concat([control_adata, target_adata], label = 'batch')
    sc.tl.rank_genes_groups(
        temp_concat, 'batch', method='wilcoxon', 
        groups = ['1'], ref = '0', rankby_abs = True, tie_correct=True
    )

    rankings = temp_concat.uns['rank_genes_groups']
    result_df = pd.DataFrame({'scores': rankings['scores']['1'],
                     'pvals_adj': rankings['pvals_adj']['1'],
                     'lfc': rankings['logfoldchanges']['1']},
                    index = rankings['names']['1'])
    return result_df



def get_eval(ctrl_adata, true_adata, pred_adata, DEGs, DEG_vals, pval_threshold, lfc_threshold):
        
    results_dict =  {}
    
    logger.debug(f"Computing R, R2, and MSE metrics")
    ctrl_X = to_dense(ctrl_adata.X)
    true_X = to_dense(true_adata.X)
    pred_X = to_dense(pred_adata.X)

    ctrl_mean = ctrl_X.mean(axis = 0)

    true_mean = true_X.mean(axis = 0)
    true_var = true_X.var(axis = 0)
    
    pred_mean = pred_X.mean(axis = 0)
    pred_var = pred_X.var(axis = 0)
    
    true_corr_mtx = np.corrcoef(true_X, rowvar=False).flatten()
    true_cov_mtx = np.cov(true_X, rowvar=False).flatten()
        
    pred_corr_mtx = np.corrcoef(pred_X, rowvar=False).flatten()
    pred_cov_mtx = np.cov(pred_X, rowvar=False).flatten()

    true_sub_diff = true_mean - ctrl_mean
    pred_sub_diff = pred_mean - ctrl_mean

    true_diff = np.expm1(true_mean) - np.expm1(ctrl_mean)
    pred_diff = np.expm1(pred_mean) - np.expm1(ctrl_mean)

    results_dict['all_genes_mean_sub_diff_R'] = pearsonr(true_sub_diff, pred_sub_diff)[0]
    results_dict['all_genes_mean_sub_diff_R2'] = pearsonr(true_sub_diff, pred_sub_diff)[0]**2
    results_dict['all_genes_mean_sub_diff_MSE'] = (np.square(true_sub_diff - pred_sub_diff)).mean(axis=0)

    results_dict['all_genes_mean_fold_diff_R'] = pearsonr(true_diff, pred_diff)[0]
    results_dict['all_genes_mean_fold_diff_R2'] = pearsonr(true_diff, pred_diff)[0]**2
    results_dict['all_genes_mean_fold_diff_MSE'] = (np.square(true_diff - pred_diff)).mean(axis=0)
    
    results_dict['all_genes_mean_R'] = pearsonr(true_mean, pred_mean)[0]
    results_dict['all_genes_mean_R2'] = pearsonr(true_mean, pred_mean)[0]**2
    results_dict['all_genes_mean_MSE'] = (np.square(true_mean - pred_mean)).mean(axis=0)

    results_dict['all_genes_var_R'] = pearsonr(true_var, pred_var)[0]
    results_dict['all_genes_var_R2'] = pearsonr(true_var, pred_var)[0]**2
    results_dict['all_genes_var_MSE'] = (np.square(true_var - pred_var)).mean(axis=0)

    results_dict['all_genes_corr_mtx_R'] = pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]
    results_dict['all_genes_corr_mtx_R2'] = pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]**2
    results_dict['all_genes_corr_mtx_MSE'] = (np.square(true_corr_mtx.flatten() - pred_corr_mtx.flatten())).mean(axis=0)

    results_dict['all_genes_cov_mtx_R'] = pearsonr(true_cov_mtx.flatten(), pred_cov_mtx.flatten())[0]
    results_dict['all_genes_cov_mtx_R2'] = pearsonr(true_cov_mtx.flatten(), pred_cov_mtx.flatten())[0]**2
    results_dict['all_genes_cov_mtx_MSE'] = (np.square(true_cov_mtx.flatten() - pred_cov_mtx.flatten())).mean(axis=0)

    if lfc_threshold:   
        significant_DEGs = DEGs[(DEGs['pvals_adj'] < pval_threshold) & (abs(DEGs) > lfc_threshold)]
    else:
        significant_DEGs = DEGs[DEGs['pvals_adj'] < pval_threshold]
    num_DEGs = len(significant_DEGs)
    DEG_vals.insert(0, num_DEGs)


    logger.debug(f"Significant DEGs {significant_DEGs}")
    
    for val in DEG_vals:

        logger.debug(f"Computing R, R2, and MSE metrics for top {val} DEGs")

        #If val == 1 we can't
        if ((val > num_DEGs) or (val == 0) or (val == 1)):
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_mean_MSE'] = None

            results_dict[f'Top_{val}_DEGs_var_R'] = None
            results_dict[f'Top_{val}_DEGs_var_R2'] = None
            results_dict[f'Top_{val}_DEGs_var_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_corr_mtx_R'] = None
            results_dict[f'Top_{val}_DEGs_corr_mtx_R2'] = None
            results_dict[f'Top_{val}_DEGs_corr_mtx_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_cov_mtx_R'] = None
            results_dict[f'Top_{val}_DEGs_cov_mtx_R2'] = None
            results_dict[f'Top_{val}_DEGs_cov_mtx_MSE'] = None
        
        else:
            top_DEGs = significant_DEGs[0:val].index.map(int)

            logger.debug(f"Top DEGs: {top_DEGs}")


            #Reshape --> If there is a single gene, the shape is (1,) and we need to reshape it to (1,1)

            ctrl_mean = to_dense(ctrl_X[:,top_DEGs]).mean(axis = 0)
            true_mean = to_dense(true_X[:,top_DEGs]).mean(axis = 0)

            logger.debug(f"Shape ctrl_adata with top DEGs: {ctrl_adata[:,top_DEGs].X.shape}, shape true_adata with top DEGs: {true_adata[:,top_DEGs].X.shape}")


            true_var = to_dense(true_X[:,top_DEGs]).var(axis = 0)
            true_corr_mtx = np.corrcoef(to_dense(true_X[:,top_DEGs]), rowvar=False).flatten()
            true_cov_mtx = np.cov(to_dense(true_X[:,top_DEGs]), rowvar=False).flatten()

            pred_mean = to_dense(pred_X[:,top_DEGs]).mean(axis = 0)
            logger.debug(f"Shape of true_mean shape: {true_mean.shape}, ctrl_mean shape: {ctrl_mean.shape}, pred_mean shape: {pred_mean.shape}")

            pred_var = to_dense(pred_X[:,top_DEGs]).var(axis = 0)
            pred_corr_mtx = np.corrcoef(to_dense(pred_X[:,top_DEGs]), rowvar=False).flatten()
            pred_cov_mtx = np.cov(to_dense(pred_X[:,top_DEGs]), rowvar=False).flatten()

            logger.debug(f"Shape of true_var shape: {true_var.shape}, pred_var shape: {pred_var.shape}")

            true_sub_diff = true_mean - ctrl_mean
            pred_sub_diff = pred_mean - ctrl_mean
        
            # inverse log1p to get sub change
            true_diff = np.expm1(true_mean) - np.expm1(ctrl_mean)
            pred_diff = np.expm1(pred_mean) - np.expm1(ctrl_mean)

            results_dict[f'Top_{val}_DEGs_sub_diff_R'] = pearsonr(true_sub_diff, pred_sub_diff)[0]
            results_dict[f'Top_{val}_DEGs_sub_diff_R2'] = pearsonr(true_sub_diff, pred_sub_diff)[0]**2
            results_dict[f'Top_{val}_DEGs_sub_diff_MSE'] = (np.square(true_sub_diff - pred_sub_diff)).mean(axis=0)
        
            results_dict[f'Top_{val}_DEGs_fold_diff_R'] = pearsonr(true_diff, pred_diff)[0]
            results_dict[f'Top_{val}_DEGs_fold_diff_R2'] = pearsonr(true_diff, pred_diff)[0]**2
            results_dict[f'Top_{val}_DEGs_fold_diff_MSE'] = (np.square(true_diff - pred_diff)).mean(axis=0)
    
            results_dict[f'Top_{val}_DEGs_mean_R'] = pearsonr(true_mean, pred_mean)[0]
            results_dict[f'Top_{val}_DEGs_mean_R2'] = pearsonr(true_mean, pred_mean)[0]**2
            results_dict[f'Top_{val}_DEGs_mean_MSE'] = (np.square(true_mean - pred_mean)).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_var_R'] = pearsonr(true_var, pred_var)[0]
            results_dict[f'Top_{val}_DEGs_var_R2'] = pearsonr(true_var, pred_var)[0]**2
            results_dict[f'Top_{val}_DEGs_var_MSE'] = (np.square(true_var - pred_var)).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_corr_mtx_R'] = pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]
            results_dict[f'Top_{val}_DEGs_corr_mtx_R2'] = pearsonr(true_corr_mtx.flatten(), pred_corr_mtx.flatten())[0]**2
            results_dict[f'Top_{val}_DEGs_corr_mtx_MSE'] = (np.square(true_corr_mtx.flatten() - pred_corr_mtx.flatten())).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_cov_mtx_R'] = pearsonr(true_cov_mtx.flatten(), pred_cov_mtx.flatten())[0]
            results_dict[f'Top_{val}_DEGs_cov_mtx_R2'] = pearsonr(true_cov_mtx.flatten(), pred_cov_mtx.flatten())[0]**2
            results_dict[f'Top_{val}_DEGs_cov_mtx_MSE'] = (np.square(true_cov_mtx.flatten() - pred_cov_mtx.flatten())).mean(axis=0)

    return results_dict

def get_DEG_Coverage_Recall(true_DEGs, pred_DEGs, p_cutoff):
    sig_true_DEGs = true_DEGs[true_DEGs['pvals_adj'] < p_cutoff]
    true_DEGs_with_direction = [get_DEG_with_direction(gene,score) for gene, score in zip(sig_true_DEGs.index, sig_true_DEGs['scores'])]
    sig_pred_DEGs = pred_DEGs[pred_DEGs['pvals_adj'] < p_cutoff]
    pred_DEGs_with_direction = [get_DEG_with_direction(gene,score) for gene, score in zip(sig_pred_DEGs.index, sig_pred_DEGs['scores'])]
    num_true_DEGs = len(true_DEGs_with_direction)
    num_pred_DEGs = len(pred_DEGs_with_direction)
    num_overlapping_DEGs = len(set(true_DEGs_with_direction).intersection(set(pred_DEGs_with_direction)))
    if num_true_DEGs > 0: 
        COVERAGE = num_overlapping_DEGs/num_true_DEGs
    else:
        COVERAGE = None
    if num_pred_DEGs > 0:
        RECALL = num_overlapping_DEGs/num_pred_DEGs
    else:
        RECALL = None
    return COVERAGE, RECALL

def get_DEGs_overlaps(true_DEGs, pred_DEGs, DEG_vals, pval_threshold, lfc_threshold):
    if lfc_threshold:
        significant_true_DEGs = true_DEGs[(true_DEGs['pvals_adj'] < pval_threshold) & (abs(true_DEGs['lfc']) > lfc_threshold)]
        significant_pred_DEGs = pred_DEGs[(pred_DEGs['pvals_adj'] < pval_threshold) & (abs(pred_DEGs['lfc']) > lfc_threshold)]
    else:
        significant_true_DEGs = true_DEGs[true_DEGs['pvals_adj'] < pval_threshold]
        significant_pred_DEGs = pred_DEGs[pred_DEGs['pvals_adj'] < pval_threshold]

    true_DEGs_for_comparison = [get_DEG_with_direction(gene,score) for gene, score in zip(significant_true_DEGs.index, significant_true_DEGs['scores'])]   
    pred_DEGs_for_comparison = [get_DEG_with_direction(gene,score) for gene, score in zip(significant_pred_DEGs.index, significant_pred_DEGs['scores'])]
    
    logger.debug(f"Computing DEG overlaps, # of significant DEGs in true data: {len(true_DEGs_for_comparison)}, # of significant DEGs in pred data: {len(pred_DEGs_for_comparison)}")
    num_DEGs = len(significant_true_DEGs)
    DEG_vals.insert(0, num_DEGs)
    
    results = {}
    for val in DEG_vals:
        if val > num_DEGs:
            results[f'Overlap_in_top_{val}_DEGs'] = None
        else:
            results[f'Overlap_in_top_{val}_DEGs'] = len(set(true_DEGs_for_comparison[0:val]).intersection(set(pred_DEGs_for_comparison[0:val])))

    intersection = len(set(true_DEGs_for_comparison).intersection(set(pred_DEGs_for_comparison)))
    union = len(set(true_DEGs_for_comparison).union(set(pred_DEGs_for_comparison)))
    if union > 0:
        results['Jaccard'] = intersection/union
    else:
        results['Jaccard'] = None
    
    return results

In [31]:
r2_and_mse = get_eval(control, true_pert, pred_pert, true_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

In [8]:
import numpy as np

logger.info("Running evaluation")

# evaluate each pair of cells and perts
eval_dict = {}
for cell_id, pert_id, ctrl_data, gt_data in loader.get_eval_data():
    logger.debug(f"Making predictions for cell: {cell_id}, pert: {pert_id}")

    preds = model.make_predict(ctrl_data, pert_id, cell_id)
    eval_dict[(cell_id, pert_id)] = (ctrl_data.X.toarray(), gt_data.X.toarray(), preds)
    break
    
if not config.etl_config.log1p:
    for (cell, pert) in eval_dict:  
        ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]
        eval_dict[(cell, pert)] =  (np.log1p(ctrl_data), np.log1p(gt_data), np.log1p(pred_pert))


2025-01-31 10:31:28,863 - INFO - Running evaluation
2025-01-31 10:31:28,864 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-01-31 10:31:28,866 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_raw.yaml
2025-01-31 10:31:28,868 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/adamson_INCOMPLETE.yaml
2025-01-31 10:31:28,869 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_HVG.yaml
2025-01-31 10:31:28,871 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/kang.yaml
2025-01-31 10:31:28,874 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/essential_gene_knockouts_raw.yaml
2025-01-31 10:31:28,876 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/

In [10]:
import scanpy as sc
from omnicell.evaluation.utils import get_DEGs, get_eval, get_DEG_Coverage_Recall, get_DEGs_overlaps
pval_threshold = 0.05
log_fold_change_threshold = 0.0

results_dict = {}
for (cell, pert) in eval_dict:  
    ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]

    pred_pert = sc.AnnData(X=pred_pert)
    true_pert = sc.AnnData(X=gt_data)
    control = sc.AnnData(X=ctrl_data)

    logger.info(f"Getting ground Truth DEGs for {pert} and {cell}")
    true_DEGs_df = get_DEGs(control, true_pert)
    signif_true_DEG = true_DEGs_df[true_DEGs_df['pvals_adj'] < pval_threshold]

    logger.info(f"Number of significant DEGS from ground truth: {signif_true_DEG.shape[0]}")

    logger.info(f"Getting predicted DEGs for {pert} and {cell}")
    pred_DEGs_df = get_DEGs(control, pred_pert)


    logger.info(f"Getting evaluation metrics for {pert} and {cell}")
    r2_and_mse = get_eval(control, true_pert, pred_pert, true_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    logger.info(f"Getting DEG overlaps for {pert} and {cell}")
    DEGs_overlaps = get_DEGs_overlaps(true_DEGs_df, pred_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    results_dict[(cell, pert)] = (r2_and_mse, DEGs_overlaps)



2025-01-31 10:41:37,979 - INFO - Getting ground Truth DEGs for RPL15 and k562
  utils.warn_names_duplicates("obs")
2025-01-31 10:41:45,897 - INFO - Number of significant DEGS from ground truth: 2159
2025-01-31 10:41:45,897 - INFO - Getting predicted DEGs for RPL15 and k562
  utils.warn_names_duplicates("obs")
2025-01-31 10:42:04,082 - INFO - Getting evaluation metrics for RPL15 and k562
2025-01-31 10:42:32,070 - INFO - Getting DEG overlaps for RPL15 and k562


In [8]:
for (cell, pert) in results_dict:
    r2_and_mse, DEGs_overlaps = results_dict[(cell, pert)]
    print(f"Cell: {cell}, Pert: {pert}")
    # print(f"R2 and MSE: {r2_and_mse}")
    print(f"DEGs Overlaps: {DEGs_overlaps}")
    print("-"*100)


Cell: k562, Pert: RPL15
DEGs Overlaps: {'Overlap_in_top_2159_DEGs': 875, 'Overlap_in_top_100_DEGs': 4, 'Overlap_in_top_50_DEGs': 0, 'Overlap_in_top_20_DEGs': 0, 'Jaccard': 0.17341105257258155}
----------------------------------------------------------------------------------------------------


In [9]:
r2_and_mse

{'all_genes_mean_sub_diff_R': np.float64(0.6966497841526466),
 'all_genes_mean_sub_diff_R2': np.float64(0.485320921759929),
 'all_genes_mean_sub_diff_MSE': np.float64(0.01054379975115533),
 'all_genes_mean_fold_diff_R': np.float64(0.7028902237471792),
 'all_genes_mean_fold_diff_R2': np.float64(0.4940546666393597),
 'all_genes_mean_fold_diff_MSE': np.float64(7.89763734207286),
 'all_genes_mean_R': np.float64(0.9880260038744728),
 'all_genes_mean_R2': np.float64(0.9761953843321598),
 'all_genes_mean_MSE': np.float64(0.010543799755630833),
 'all_genes_var_R': np.float64(0.8846038081846144),
 'all_genes_var_R2': np.float64(0.7825238974547221),
 'all_genes_var_MSE': np.float64(0.003680699246692525),
 'all_genes_corr_mtx_R': np.float64(0.5764657556881528),
 'all_genes_corr_mtx_R2': np.float64(0.3323127674811131),
 'all_genes_corr_mtx_MSE': np.float64(0.0074322404434879),
 'all_genes_cov_mtx_R': np.float64(0.7999105612729749),
 'all_genes_cov_mtx_R2': np.float64(0.6398569060360457),
 'all_gen

In [57]:
def efficient_correlation_r2_pearson(U, V):
    """
    Compute squared Pearson correlation between correlation matrices 
    without forming full matrices
    """
    n = U.shape[0]
    
    # Normalize U and V
    U_norms = np.sqrt(np.sum(U * U, axis=1, keepdims=True))
    U_normalized = U / U_norms
    
    V_norms = np.sqrt(np.sum(V * V, axis=1, keepdims=True))
    V_normalized = V / V_norms
    
    # Compute correlation matrices implicitly
    UTU = U_normalized.T @ U_normalized
    VTV = V_normalized.T @ V_normalized
    
    # Compute means without forming full matrices
    mean_UTU = np.sum(UTU) / (n * n)
    mean_VTV = np.sum(VTV) / (n * n)
    
    # Center the matrices implicitly
    UTU_centered = UTU - mean_UTU
    VTV_centered = VTV - mean_VTV
    
    # Compute correlation using trace tricks
    numerator = np.sum(UTU_centered * VTV_centered)
    denominator = np.sqrt(np.sum(UTU_centered * UTU_centered) * 
                         np.sum(VTV_centered * VTV_centered))
    
    r = numerator / denominator
    return r * r

def naive_correlation_r2(X1, X2):
    """
    Original implementation using flattened matrices
    """
    true_corr_mtx = np.corrcoef(X1).flatten()
    pred_corr_mtx = np.corrcoef(X2).flatten()
    print(true_corr_mtx.shape, pred_corr_mtx.shape)
    
    return pearsonr(true_corr_mtx, pred_corr_mtx)[0] ** 2

# Test both implementations
np.random.seed(42)
n, k = 2000, 10
U = np.random.randn(n, k)
V = np.random.randn(n, k)
V[:, :1] = U[:, :1]  # Make them partially similar

naive_result = naive_correlation_r2(U, V)
efficient_result = efficient_correlation_r2_pearson(U, V)

print(f"Naive R²: {naive_result:.6f}")
print(f"Efficient R²: {efficient_result:.6f}")

(4000000,) (4000000,)
Naive R²: 0.008483
Efficient R²: 0.994205


In [59]:
def efficient_correlation_r2_pearson(U, V):
    """
    Compute squared Pearson correlation between correlation matrices 
    without forming full matrices, matching np.corrcoef exactly
    """
    # First compute the correlation matrices exactly as numpy does
    true_corr = np.corrcoef(U.T)  # Note: using .T to match np.corrcoef default behavior
    pred_corr = np.corrcoef(V.T)
    
    # Now compute Pearson R² between these matrices without flattening
    mean_true = np.mean(true_corr)
    mean_pred = np.mean(pred_corr)
    
    # Center the matrices
    true_centered = true_corr - mean_true
    pred_centered = pred_corr - mean_pred
    
    # Compute correlation
    numerator = np.sum(true_centered * pred_centered)
    denominator = np.sqrt(np.sum(true_centered * true_centered) * 
                         np.sum(pred_centered * pred_centered))
    
    r = numerator / denominator
    return r * r

def naive_correlation_r2(X1, X2):
    """
    Original implementation using flattened matrices
    """
    true_corr_mtx = np.corrcoef(X1).flatten()
    pred_corr_mtx = np.corrcoef(X2).flatten()
    
    return pearsonr(true_corr_mtx, pred_corr_mtx)[0] ** 2

# Test both implementations
np.random.seed(42)
n, k = 2000, 10
U = np.random.randn(n, k)
V = np.random.randn(n, k)
V[:, :1] = U[:, :1]  # Make them partially similar

naive_result = naive_correlation_r2(U, V)
efficient_result = efficient_correlation_r2_pearson(U, V)

print(f"Naive R²: {naive_result:.6f}")
print(f"Efficient R²: {efficient_result:.6f}")

# Let's also print the actual correlation matrices to verify
print("\nFirst few elements of correlation matrices:")
print("True corr (naive):", np.corrcoef(U).flatten()[:5])
print("True corr (efficient):", np.corrcoef(U.T).flatten()[:5])

Naive R²: 0.008483
Efficient R²: 0.993551

First few elements of correlation matrices:
True corr (naive): [ 1.         -0.11051098 -0.31666654 -0.15239374  0.07219922]
True corr (efficient): [ 1.          0.02817255  0.0274025   0.01609295 -0.0151037 ]


In [60]:
def efficient_correlation_r2_pearson(U, V):
    """
    Compute squared Pearson correlation between correlation matrices 
    without forming full matrices, matching np.corrcoef exactly
    """
    # First compute the correlation matrices exactly as numpy does
    true_corr = np.corrcoef(U, rowvar=False)  # correlations between columns
    pred_corr = np.corrcoef(V, rowvar=False)
    
    # Now compute Pearson R² between these matrices without flattening
    mean_true = np.mean(true_corr)
    mean_pred = np.mean(pred_corr)
    
    # Center the matrices
    true_centered = true_corr - mean_true
    pred_centered = pred_corr - mean_pred
    
    # Compute correlation
    numerator = np.sum(true_centered * pred_centered)
    denominator = np.sqrt(np.sum(true_centered * true_centered) * 
                         np.sum(pred_centered * pred_centered))
    
    r = numerator / denominator
    return r * r

def naive_correlation_r2(X1, X2):
    """
    Original implementation using flattened matrices
    """
    true_corr_mtx = np.corrcoef(X1, rowvar=False).flatten()  # correlations between columns
    pred_corr_mtx = np.corrcoef(X2, rowvar=False).flatten()
    print(true_corr_mtx.shape, pred_corr_mtx.shape)
    
    return pearsonr(true_corr_mtx, pred_corr_mtx)[0] ** 2

# Test both implementations
np.random.seed(42)
n, k = 2000, 10
U = np.random.randn(n, k)
V = np.random.randn(n, k)
V[:, :1] = U[:, :1]  # Make them partially similar

naive_result = naive_correlation_r2(U, V)
efficient_result = efficient_correlation_r2_pearson(U, V)

print(f"Naive R²: {naive_result:.6f}")
print(f"Efficient R²: {efficient_result:.6f}")

# Let's also print the first few elements to verify they're the same
print("\nFirst few elements of correlation matrices:")
print("True corr (naive):", np.corrcoef(U, rowvar=False).flatten()[:5])
print("True corr (efficient):", np.corrcoef(U, rowvar=False).flatten()[:5])

(100,) (100,)
Naive R²: 0.993551
Efficient R²: 0.993551

First few elements of correlation matrices:
True corr (naive): [ 1.          0.02817255  0.0274025   0.01609295 -0.0151037 ]
True corr (efficient): [ 1.          0.02817255  0.0274025   0.01609295 -0.0151037 ]


In [1]:
import numpy as np
from scipy.stats import pearsonr
import time
import psutil
import os

def get_memory_usage():
    """Return memory usage in GB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024

def efficient_correlation_r2_pearson(U, V):
    """Efficient version"""
    U_centered = U - U.mean(axis=1, keepdims=True)
    V_centered = V - V.mean(axis=1, keepdims=True)
    
    U_norms = np.sqrt(np.sum(U_centered * U_centered, axis=1, keepdims=True))
    V_norms = np.sqrt(np.sum(V_centered * V_centered, axis=1, keepdims=True))
    
    U_normalized = U_centered / U_norms
    V_normalized = V_centered / V_norms
    
    UTU = U_normalized @ U_normalized.T
    VTV = V_normalized @ V_normalized.T
    
    mean_UTU = np.mean(UTU)
    mean_VTV = np.mean(VTV)
    
    UTU_centered = UTU - mean_UTU
    VTV_centered = VTV - mean_VTV
    
    numerator = np.sum(UTU_centered * VTV_centered)
    denominator = np.sqrt(np.sum(UTU_centered * UTU_centered) * 
                         np.sum(VTV_centered * VTV_centered))
    
    return (numerator / denominator) ** 2

def naive_correlation_r2(X1, X2):
    """Naive version"""
    true_corr_mtx = np.corrcoef(X1, rowvar=True).flatten()
    pred_corr_mtx = np.corrcoef(X2, rowvar=True).flatten()
    return pearsonr(true_corr_mtx, pred_corr_mtx)[0] ** 2

# Test sizes (n_samples, keeping n_features=10 fixed)
test_sizes = [1000, 2000, 5000, 10000, 20000]
k = 10  # fixed number of features

results = []
initial_memory = get_memory_usage()

print("Matrix Size | Efficient Time (s) | Naive Time (s) | Speedup | Memory Usage (GB) | R² Match")
print("-" * 80)

for n in test_sizes:
    # Generate random matrices
    np.random.seed(42)
    U = np.random.randn(n, k)
    V = np.random.randn(n, k)
    V[:, :1] = U[:, :1]  # Make them partially similar
    
    # Time efficient version
    start_memory = get_memory_usage()
    start_time = time.time()
    eff_result = efficient_correlation_r2_pearson(U, V)
    eff_time = time.time() - start_time
    eff_memory = get_memory_usage() - start_memory
    
    # Time naive version
    start_memory = get_memory_usage()
    start_time = time.time()
    naive_result = naive_correlation_r2(U, V)
    naive_time = time.time() - start_time
    naive_memory = get_memory_usage() - start_memory
    
    # Calculate metrics
    speedup = naive_time / eff_time
    max_memory = max(eff_memory, naive_memory)
    results_match = np.abs(eff_result - naive_result) < 1e-10
    
    print(f"{n:>10} | {eff_time:>15.4f} | {naive_time:>13.4f} | {speedup:>7.2f}x | {max_memory:>15.4f} | {results_match}")
    
    # Clear some memory
    del U, V

print(f"\nPeak memory usage: {get_memory_usage() - initial_memory:.4f} GB")

Matrix Size | Efficient Time (s) | Naive Time (s) | Speedup | Memory Usage (GB) | R² Match
--------------------------------------------------------------------------------
      1000 |          0.0195 |        0.0287 |    1.48x |          0.0101 | True
      2000 |          0.0685 |        0.1144 |    1.67x |          0.0337 | True
      5000 |          0.4692 |        0.7709 |    1.64x |          0.0232 | True
     10000 |          1.9041 |        3.1527 |    1.66x |          0.0039 | True
     20000 |          8.4055 |       12.9496 |    1.54x |          0.0060 | True

Peak memory usage: 0.0500 GB


In [None]:
def efficient_correlation_r2_pearson_lowrank(U, V):
    """
    Compute squared Pearson correlation between correlation matrices 
    exploiting low rank structure (n x k where k << n)
    """
    # Center the data
    U_centered = U - U.mean(axis=1, keepdims=True)
    V_centered = V - V.mean(axis=1, keepdims=True)
    
    # Normalize rows
    U_norms = np.sqrt(np.sum(U_centered * U_centered, axis=1, keepdims=True))
    V_norms = np.sqrt(np.sum(V_centered * V_centered, axis=1, keepdims=True))
    
    U_normalized = U_centered / U_norms
    V_normalized = V_centered / V_norms
    
    # Instead of forming n×n matrices, work with n×k matrices
    # For the means, use the fact that tr(UU^T) = tr(U^TU)
    UTU = U_normalized.T @ U_normalized  # k×k matrix
    VTV = V_normalized.T @ V_normalized  # k×k matrix
    
    mean_UTU = np.sum(UTU) / (U.shape[0] ** 2)  # scalar
    mean_VTV = np.sum(VTV) / (V.shape[0] ** 2)  # scalar
    
    # For the correlation computation:
    # Instead of computing full UU^T * VV^T (n×n matrices)
    # Use the fact that tr((UU^T)(VV^T)) = tr((U^TV)(V^TU))
    UV = U_normalized.T @ V_normalized  # k×k matrix
    
    # Compute terms for numerator and denominator
    trace_UUTVVT = np.sum(UV * UV.T)  # scalar
    trace_UUTUUT = np.sum(UTU * UTU)  # scalar
    trace_VVTVVT = np.sum(VTV * VTV)  # scalar
    
    n_squared = U.shape[0] ** 2
    
    # Compute correlation
    numerator = trace_UUTVVT - n_squared * mean_UTU * mean_VTV
    denominator = np.sqrt(
        (trace_UUTUUT - n_squared * mean_UTU ** 2) * (trace_VVTVVT - n_squared * mean_VTV ** 2)
    )
    
    r = numerator / denominator
    return r * r

# Test and compare all implementations
np.random.seed(42)
test_sizes = [1000, 2000, 5000, 10000, 20000]
k = 10  # fixed number of features

print("Matrix Size | Efficient Low-Rank (s) | Original Efficient (s) | Naive (s) | Low-Rank Speedup")
print("-" * 90)

for n in test_sizes:
    # Generate random matrices
    U = np.random.randn(n, k)
    V = np.random.randn(n, k)
    V[:, :1] = U[:, :1]  # Make them partially similar
    
    # Time low-rank efficient version
    start_time = time.time()
    lowrank_result = efficient_correlation_r2_pearson_lowrank(U, V)
    lowrank_time = time.time() - start_time
    
    # Time original efficient version
    start_time = time.time()
    eff_result = efficient_correlation_r2_pearson(U, V)
    eff_time = time.time() - start_time
    
    # Time naive version
    start_time = time.time()
    naive_result = naive_correlation_r2(U, V)
    naive_time = time.time() - start_time
    
    # Calculate speedups
    speedup_vs_efficient = eff_time / lowrank_time
    speedup_vs_naive = naive_time / lowrank_time
    
    print(f"{n:>10} | {lowrank_time:>18.4f} | {eff_time:>19.4f} | {naive_time:>9.4f} | {speedup_vs_naive:>8.2f}x")
    
    # Verify results match
    assert np.abs(lowrank_result - naive_result) < 1e-3, "Results don't match!"

Matrix Size | Efficient Low-Rank (s) | Original Efficient (s) | Naive (s) | Low-Rank Speedup
------------------------------------------------------------------------------------------
      1000 |             0.0004 |              0.0129 |    0.0281 |    65.57x
      2000 |             0.0005 |              0.1083 |    0.2357 |   501.57x
      5000 |             0.0013 |              0.8128 |    1.5112 |  1142.44x
     10000 |             0.0084 |              2.3063 |    3.0373 |   361.15x
     20000 |             0.0048 |              8.4146 |   12.9380 |  2719.95x


In [4]:
lowrank_result, naive_result

(np.float64(0.006868941410695422), np.float64(0.006910486571052829))

In [None]:
def efficient_correlation_r2_pearson_different_ranks(U, V):
    """
    Compute squared Pearson correlation between correlation matrices 
    for matrices of different ranks
    U: (n_samples x k1) matrix
    V: (n_samples x k2) matrix
    """
    # Center the data
    U_centered = U - U.mean(axis=1, keepdims=True)
    V_centered = V - V.mean(axis=1, keepdims=True)
    
    # Normalize rows
    U_norms = np.sqrt(np.sum(U_centered * U_centered, axis=1, keepdims=True))
    V_norms = np.sqrt(np.sum(V_centered * V_centered, axis=1, keepdims=True))
    
    U_normalized = U_centered / U_norms
    V_normalized = V_centered / V_norms
    
    # Work with smaller matrices
    UTU = U_normalized.T @ U_normalized  # k1×k1 matrix
    VTV = V_normalized.T @ V_normalized  # k2×k2 matrix
    UV = U_normalized.T @ V_normalized   # k1×k2 matrix
    
    # Compute means efficiently
    mean_UTU = np.sum(UTU) / (U.shape[0] ** 2)
    mean_VTV = np.sum(VTV) / (V.shape[0] ** 2)
    
    # Use trace tricks with different sized matrices
    # For matrices of different sizes, tr(UU^T VV^T) = tr((U^TV)(V^TU)) = sum(UV * VU)
    # print(UV.shape, V_normalized.T.shape, U_normalized.shape)
    trace_UUTVVT = np.sum(UV.T * (V_normalized.T @ U_normalized))
    trace_UUTUUT = np.sum(UTU * UTU)
    trace_VVTVVT = np.sum(VTV * VTV)
    
    n_squared = U.shape[0] ** 2
    
    # Compute correlation
    numerator = trace_UUTVVT - n_squared * mean_UTU * mean_VTV
    denominator = np.sqrt((trace_UUTUUT - n_squared * mean_UTU ** 2) * 
                         (trace_VVTVVT - n_squared * mean_VTV ** 2))
    
    r = numerator / denominator
    return r * r

# Test with different ranks
np.random.seed(42)
test_sizes = [(1000, 5, 10), (2000, 10, 20), (5000, 15, 30)]  # (n_samples, k1, k2)

print("Matrix Size | Rank1 | Rank2 | Time (s) | Result")
print("-" * 60)

for n, k1, k2 in test_sizes:
    # Generate random matrices of different ranks
    U = np.random.randn(n, k1)
    V = np.random.randn(n, k2)
    
    # Make them partially similar in their overlapping dimensions
    min_k = min(k1, k2)
    V[:, :min_k] = U[:, :min_k]
    
    start_time = time.time()
    result = efficient_correlation_r2_pearson_different_ranks(U, V)
    computation_time = time.time() - start_time
    
    print(f"{n:>10} | {k1:>5} | {k2:>5} | {computation_time:>8.4f} | {result:.6f}")
    
    naive_result = naive_correlation_r2(U, V)
    print(f"Naive result: {naive_result:.6f}")
    assert np.abs(result - naive_result) < 1e-3, f"Results don't match! {result} vs {naive_result}"

Matrix Size | Rank1 | Rank2 | Time (s) | Result
------------------------------------------------------------
      1000 |     5 |    10 |   0.0004 | 0.372763
Naive result: 0.372764
      2000 |    10 |    20 |   0.0007 | 0.457513
Naive result: 0.457513
      5000 |    15 |    30 |   0.0549 | 0.469620


In [23]:
import numpy as np
from scipy.stats import pearsonr
import time
import psutil
import os

def get_memory_usage():
    """Return memory usage in GB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024

# Test sizes (n_samples, k1, k2)
test_sizes = [
    (1000, 5, 10),
    (2000, 10, 20),
    (5000, 15, 30),
    (10000, 20, 40),
    (20000, 25, 50)
]

results = []
initial_memory = get_memory_usage()

print("Matrix Size | Ranks  | Low-Rank Time (s) | Naive Time (s) | Speedup | Memory (GB) | R² Match")
print("-" * 90)

for n, k1, k2 in test_sizes:
    # Generate random matrices
    np.random.seed(42)
    U = np.random.randn(n, k1)
    V = np.random.randn(n, k2)
    V[:, :min(k1, k2)] = U[:, :min(k1, k2)]  # Make them partially similar
    
    # Time low-rank version
    start_memory = get_memory_usage()
    start_time = time.time()
    lowrank_result = efficient_correlation_r2_pearson_different_ranks(U, V)
    lowrank_time = time.time() - start_time
    lowrank_memory = get_memory_usage() - start_memory
    
    # Time naive version
    start_memory = get_memory_usage()
    start_time = time.time()
    naive_result = naive_correlation_r2(U, V)
    naive_time = time.time() - start_time
    naive_memory = get_memory_usage() - start_memory
    
    # Calculate metrics
    speedup = naive_time / lowrank_time
    max_memory = max(lowrank_memory, naive_memory)
    results_match = np.abs(lowrank_result - naive_result) < 1e-10
    
    print(f"{n:>10} | {k1:>2},{k2:>2} | {lowrank_time:>15.4f} | {naive_time:>13.4f} | {speedup:>7.2f}x | {max_memory:>10.4f} | {results_match}")
    
    # Optional: print detailed results
    if not results_match:
        print(f"  Low-rank result: {lowrank_result:.6f}")
        print(f"  Naive result:    {naive_result:.6f}")
        print(f"  Difference:      {abs(lowrank_result - naive_result):.2e}")
    
    # Clear some memory
    del U, V

print(f"\nPeak memory usage: {get_memory_usage() - initial_memory:.4f} GB")

Matrix Size | Ranks  | Low-Rank Time (s) | Naive Time (s) | Speedup | Memory (GB) | R² Match
------------------------------------------------------------------------------------------
      1000 |  5,10 |          0.0004 |        0.1288 |  312.15x |     0.0444 | False
  Low-rank result: 0.372763
  Naive result:    0.372764
  Difference:      2.22e-07
      2000 | 10,20 |          0.0088 |        0.2531 |   28.76x |     0.0000 | False
  Low-rank result: 0.450282
  Naive result:    0.450281
  Difference:      1.14e-06
      5000 | 15,30 |          0.0362 |        1.1862 |   32.79x |     0.0232 | False
  Low-rank result: 0.468363
  Naive result:    0.468362
  Difference:      5.20e-07
     10000 | 20,40 |          0.0669 |        3.6951 |   55.21x |     0.0000 | False
  Low-rank result: 0.471352
  Naive result:    0.471352
  Difference:      1.28e-07
     20000 | 25,50 |          0.0160 |       13.2281 |  827.16x |     0.0065 | False
  Low-rank result: 0.477505
  Naive result:    0.477505

In [27]:
def efficient_covariance_r2_pearson_different_ranks(U, V):
    """
    Compute squared Pearson correlation between covariance matrices 
    for matrices of different ranks
    U: (n_samples x k1) matrix
    V: (n_samples x k2) matrix
    """
    # Center the data
    U_centered = U - U.mean(axis=1, keepdims=True)
    V_centered = V - V.mean(axis=1, keepdims=True)
    
    # For covariance, we don't normalize by row norms
    # Compute smaller matrices directly
    UTU = U_centered.T @ U_centered  # k1×k1 matrix
    VTV = V_centered.T @ V_centered  # k2×k2 matrix
    UV = U_centered.T @ V_centered   # k1×k2 matrix
    
    # Compute means efficiently
    mean_UTU = np.sum(UTU) / (U.shape[0] ** 2)
    mean_VTV = np.sum(VTV) / (V.shape[0] ** 2)
    
    # Use trace tricks with different sized matrices
    trace_UUTVVT = np.sum(UV.T * (V_centered.T @ U_centered))
    trace_UUTUUT = np.sum(UTU * UTU)
    trace_VVTVVT = np.sum(VTV * VTV)
    
    n_squared = U.shape[0] ** 2
    
    # Compute correlation
    numerator = trace_UUTVVT - n_squared * mean_UTU * mean_VTV
    denominator = np.sqrt((trace_UUTUUT - n_squared * mean_UTU ** 2) * 
                         (trace_VVTVVT - n_squared * mean_VTV ** 2))
    
    r = numerator / denominator
    return r * r

def naive_covariance_r2(X1, X2):
    """
    Original implementation using flattened matrices
    """
    true_cov_mtx = np.cov(X1, rowvar=True).flatten()
    pred_cov_mtx = np.cov(X2, rowvar=True).flatten()
    print(true_cov_mtx.shape, pred_cov_mtx.shape)
    
    return pearsonr(true_cov_mtx, pred_cov_mtx)[0] ** 2

# Benchmark both implementations
import numpy as np
from scipy.stats import pearsonr
import time
import psutil
import os

def get_memory_usage():
    """Return memory usage in GB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024

# Test sizes (n_samples, k1, k2)
test_sizes = [
    (1000, 5, 10),
    (2000, 10, 20),
    (5000, 15, 30),
    (10000, 20, 40),
    (20000, 25, 50)
]

results = []
initial_memory = get_memory_usage()

print("Matrix Size | Ranks  | Low-Rank Time (s) | Naive Time (s) | Speedup | Memory (GB) | R² Match")
print("-" * 90)

for n, k1, k2 in test_sizes:
    # Generate random matrices
    np.random.seed(42)
    U = np.random.randn(n, k1)
    V = np.random.randn(n, k2)
    V[:, :min(k1, k2)] = U[:, :min(k1, k2)]  # Make them partially similar
    
    # Time low-rank version
    start_memory = get_memory_usage()
    start_time = time.time()
    lowrank_result = efficient_covariance_r2_pearson_different_ranks(U, V)
    lowrank_time = time.time() - start_time
    lowrank_memory = get_memory_usage() - start_memory
    
    # Time naive version
    start_memory = get_memory_usage()
    start_time = time.time()
    naive_result = naive_covariance_r2(U, V)
    naive_time = time.time() - start_time
    naive_memory = get_memory_usage() - start_memory
    
    # Calculate metrics
    speedup = naive_time / lowrank_time
    max_memory = max(lowrank_memory, naive_memory)
    results_match = np.abs(lowrank_result - naive_result) < 1e-10
    
    print(f"{n:>10} | {k1:>2},{k2:>2} | {lowrank_time:>15.4f} | {naive_time:>13.4f} | {speedup:>7.2f}x | {max_memory:>10.4f} | {results_match}")
    
    # Optional: print detailed results
    if not results_match:
        print(f"  Low-rank result: {lowrank_result:.6f}")
        print(f"  Naive result:    {naive_result:.6f}")
        print(f"  Difference:      {abs(lowrank_result - naive_result):.2e}")
    
    # Clear some memory
    del U, V

print(f"\nPeak memory usage: {get_memory_usage() - initial_memory:.4f} GB")

Matrix Size | Ranks  | Low-Rank Time (s) | Naive Time (s) | Speedup | Memory (GB) | R² Match
------------------------------------------------------------------------------------------
(1000000,) (1000000,)
      1000 |  5,10 |          0.0003 |        0.0572 |  177.00x |     0.0019 | False
  Low-rank result: 0.425432
  Naive result:    0.425432
  Difference:      2.76e-07
(4000000,) (4000000,)
      2000 | 10,20 |          0.0004 |        0.1078 |  253.86x |     0.0000 | False
  Low-rank result: 0.473235
  Naive result:    0.473234
  Difference:      8.52e-07
(25000000,) (25000000,)
      5000 | 15,30 |          0.0141 |        0.8345 |   59.02x |     0.0232 | False
  Low-rank result: 0.484887
  Naive result:    0.484887
  Difference:      4.63e-07
(100000000,) (100000000,)
     10000 | 20,40 |          0.4543 |        2.7394 |    6.03x |     0.0000 | False
  Low-rank result: 0.484533
  Naive result:    0.484533
  Difference:      1.25e-07
(400000000,) (400000000,)
     20000 | 25,50 |