In [1]:
import numpy as np
import random
from scipy.spatial.distance import pdist, squareform, cdist
from math import comb
import gc
import itertools
from dataprep import *
import scanpy as sc

Version that was able to run in 120 minutes. The test_permutation() took 6 seconds with multithreading and 30 seconds without

In [2]:
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.special import comb
from concurrent.futures import ThreadPoolExecutor, as_completed

def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    distances_1_to_2 = np.min(distance_matrix[np.ix_(index_list[:len_pattern], index_list[len_pattern:])], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)

def test_permutation(pattern, control, n_permutations: int = 9999, return_distances: bool = False):
    ''' pattern and control are subsets of adata.obsm['latent'].
    '''
    combined = np.concatenate([pattern, control])
    distance_matrix = squareform(pdist(combined, metric='cityblock'))
    len_combined = len(combined)
    num_pattern = len(pattern)
    observed_statistic = chamfer_L1_distance(distance_matrix, list(range(len_combined)))

    if num_pattern < 15:
        total_permutations = comb(len_combined, num_pattern)
        if n_permutations > total_permutations:
            exact_test = True
            n_permutations = int(total_permutations)
        else:
            exact_test = False
    else:
        exact_test = False

    index_lists = np.apply_along_axis(np.random.permutation, 1, np.tile(list(range(len_combined)), (n_permutations, 1)))

    chamfer_distances = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(chamfer_L1_distance, distance_matrix, index_list) for index_list in index_lists]
        for future in as_completed(futures):
            chamfer_distances.append(future.result())

    chamfer_distances = np.array(chamfer_distances)

    eps = (0 if not np.issubdtype(observed_statistic.dtype, np.inexact)
           else np.finfo(observed_statistic.dtype).eps * 100)
    gamma = np.abs(eps * observed_statistic)
    cmps_greater = chamfer_distances >= observed_statistic - gamma
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum() + adjustment) / (n_permutations + adjustment)
    p_value = pvalues_greater

    if return_distances:
        return p_value, observed_statistic, chamfer_distances
    else:
        return p_value

def compute_power_permutation(params):
    try:
        strength, count, sample = params
        significant_count = 0
        bonferroni_count = 0
        # Given that random patterns have only 1442 cells simulated, we don't calculate the power for these above 1400 so that we don't need to sample with replacement
        if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
            return (f'{strength}_{count}_{sample}', -1)
    
        for i in range(1000):
            # sample new gene. No random seed so that every time a different "gene" is sampled.
            pattern = subset_power_analysis(adata_test, mixed_patterns = True, pattern_strength= strength, rna_count = count, sample_size = sample, random_seed=False)
            control = subset_power_analysis(adata_test, pattern = 'random', mixed_patterns = False, rna_count = count, sample_size = sample, random_seed=False)

            # Calculate null distribution and pvalue
            pvalue = test_permutation(pattern.obsm["latent"], control.obsm["latent"], n_permutations=9999, return_distances = False)
            
            critical_value = 0.05
            adjusted_critical_value = critical_value / 5000
            if pvalue < critical_value:
                significant_count += 1
            if pvalue < adjusted_critical_value:
                bonferroni_count += 1
        result = (f'{strength}_{count}_{sample}', significant_count/1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count/1000)

        return result, bonferroni_result
    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e
        
if __name__ == '__main__':
    # Set up logging
    logging.basicConfig(filename='permute.log', filemode='a', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)
    logging.info('Loading adata')

    adata_split_cellID = sc.read_h5ad("/media/gambino/students_workdir/nynke/new_model_with_cell_id_left_out_custom_nynke_panel_simulated_embeddings_adata.h5ad")
    adata_split_cellID = initialize_adata(adata_split_cellID)
    adata_test = adata_split_cellID[adata_split_cellID.obs['cell_id'].isin(adata_split_cellID.uns['test_cellIDs'])]

    strengths = ['strong', 'intermediate', 'low']
    counts = adata_test.obs['rna_count'].unique()
    #samples = [5, 9, 15, 27, 46, 81, 142, 247, 432, 753, 1315, 2297, 4009, 7000]
    samples = [1315, 2297, 4009, 7000]

    # Create a list of all combinations of strength, count, and sample
    combinations = [(strength, count, sample) for strength in strengths for count in counts for sample in samples]

    # Set the start method to 'spawn'
    mp.set_start_method('spawn', force=True)

    # Create a multiprocessing pool and compute the power for each combination
    with Pool(20) as p:
        results, bonferroni_result = p.map(compute_power_permutation, combinations)

    # in case want to do single processing:
    # results = [compute_power_permutation(combination) for combination in combinations]

    # Convert the results to a dictionary
    power_results = dict(results)
    bonferroni_power_results = dict(bonferroni_result)

    path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_uncorrected.pkl"
    bonferroni_path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_bonferroni.pkl"

    # Open the file in write-binary mode and dump the object
    with open(path, 'wb') as f:
        pickle.dump(power_results, f)

    # Open the file in write-binary mode and dump the object
    with open(bonferroni_path, 'wb') as f:
        pickle.dump(bonferroni_power_results, f)


In [None]:
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.special import comb
from concurrent.futures import ProcessPoolExecutor, as_completed
import logging
import pickle
import scanpy as sc
from multiprocessing import Pool, set_start_method

def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    distances_1_to_2 = np.min(distance_matrix[np.ix_(index_list[:len_pattern], index_list[len_pattern:])], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)

def test_permutation(pattern, control, n_permutations: int = 9999, exact_test: bool = False, return_distances: bool = False):
    ''' pattern and control are subsets of adata.obsm['latent'].
    '''
    combined = np.concatenate([pattern, control])
    distance_matrix = squareform(pdist(combined, metric='cityblock'))
    len_combined = len(combined)
    num_pattern = len(pattern)
    observed_statistic = chamfer_L1_distance(distance_matrix, list(range(len_combined)))

    index_lists = np.apply_along_axis(np.random.permutation, 1, np.tile(list(range(len_combined)), (n_permutations, 1)))

    chamfer_distances = []
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(chamfer_L1_distance, distance_matrix, index_list) for index_list in index_lists]
        for future in as_completed(futures):
            chamfer_distances.append(future.result())

    chamfer_distances = np.array(chamfer_distances)

    eps = (0 if not np.issubdtype(observed_statistic.dtype, np.inexact)
           else np.finfo(observed_statistic.dtype).eps * 100)
    gamma = np.abs(eps * observed_statistic)
    cmps_greater = chamfer_distances >= observed_statistic - gamma
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum() + adjustment) / (n_permutations + adjustment)
    p_value = pvalues_greater

    if return_distances:
        return p_value, observed_statistic, chamfer_distances
    else:
        return p_value

def compute_power_permutation(params):
    strength, count, sample = params
    significant_count = 0
    bonferroni_count = 0

    # Given that random patterns have only 1442 cells simulated, we don't calculate the power for these above 1400 so that we don't need to sample with replacement
    if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
        with open('/media/gambino/students_workdir/nynke/blurry/results.txt', 'a') as f:
            f.write(f'{strength}\t{count}\t{sample}\t{significant_count / 1000}\t{bonferroni_count / 1000}\n')
        return (f'{strength}_{count}_{sample}', -1)
    
    # Count max number of permutations with Combination rule nCr, where r is the pattern size
    if sample < 15:
        total_permutations = comb(sample*2, sample) # built in implementation of nCr rule.
        # Adjust n_permutations if it's larger than total_permutations
        if n_permutations > total_permutations:
            exact_test = True
            n_permutations = int(total_permutations)
        else:
            n_permutations = 9999
            exact_test = False
    else:
        # If num_pattern is 15, the total combinations are 1.5e8, which already is much larger than 9999. So we skip calculating the factorials for 15+ to save compute time. 
        n_permutations = 9999
        exact_test = False

    try:
        pvalues = np.zeros(1000)
        for i in range(1000):
            # sample new gene. No random seed so that every time a different "gene" is sampled.
            pattern = subset_power_analysis(adata_test, mixed_patterns = True, pattern_strength= strength, rna_count = count, sample_size = sample, random_seed=False)
            control = subset_power_analysis(adata_test, pattern = 'random', mixed_patterns = False, rna_count = count, sample_size = sample, random_seed=False)

            # Calculate null distribution and pvalue
            pvalues[i] = test_permutation(pattern.obsm["latent"], control.obsm["latent"], n_permutations=9999, exact_test = exact_test, return_distances = False)
            
        critical_value = 0.05
        adjusted_critical_value = critical_value / 5000
        significant_count = np.sum(pvalues < critical_value)
        bonferroni_count = np.sum(pvalues < adjusted_critical_value)

        with open('/media/gambino/students_workdir/nynke/blurry/results.txt', 'a') as f:
            f.write(f'{strength}\t{count}\t{sample}\t{significant_count / 1000}\t{bonferroni_count / 1000}\n')

        result = (f'{strength}_{count}_{sample}', significant_count/1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count/1000)

        return result, bonferroni_result
    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e

if __name__ == '__main__':
    # Set up logging
    logging.basicConfig(filename='permute.log', filemode='a', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)
    
    logging.info('Loading adata')
    adata_split_cellID = sc.read_h5ad("/media/gambino/students_workdir/nynke/new_model_with_cell_id_left_out_custom_nynke_panel_simulated_embeddings_adata.h5ad")
    adata_split_cellID = initialize_adata(adata_split_cellID)
    adata_test = adata_split_cellID[adata_split_cellID.obs['cell_id'].isin(adata_split_cellID.uns['test_cellIDs'])]

    strengths = ['strong', 'intermediate', 'low']
    counts = adata_test.obs['rna_count'].unique()
    #samples = [5, 9, 15, 27, 46, 81, 142, 247, 432, 753, 1315, 2297, 4009, 7000]
    samples = [1315, 2297, 4009, 7000]

    # Create a list of all combinations of strength, count, and sample
    combinations = [(strength, count, sample) for strength in strengths for count in counts for sample in samples]

    # Set the start method to 'spawn'
    set_start_method('spawn', force=True)

    # Create a multiprocessing pool and compute the power for each combination
    with Pool(20) as p:
        results = p.map(compute_power_permutation, combinations)

    # Convert the results to a dictionary
    power_results = {result[0]: result[1] for result, _ in results if result}
    bonferroni_power_results = {bonferroni_result[0]: bonferroni_result[1] for _, bonferroni_result in results if bonferroni_result}

    path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_uncorrected.pkl"
    bonferroni_path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_bonferroni.pkl"

    # Open the file in write-binary mode and dump the object
    with open(path, 'wb') as f:
        pickle.dump(power_results, f)

    # Open the file in write-binary mode and dump the object
    with open(bonferroni_path, 'wb') as f:
        pickle.dump(bonferroni_power_results, f)

In [3]:
def chamfer_L1_distance_batch(distance_matrix, index_lists):
    len_pattern = index_lists.shape[1] // 2
    distances_1_to_2 = np.min(distance_matrix[index_lists[:, :len_pattern], :][:, :, index_lists[:, len_pattern:]], axis=2)
    distances_2_to_1 = np.min(distance_matrix[index_lists[:, len_pattern:], :][:, :, index_lists[:, :len_pattern]], axis=2)
    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

def chamfer_L1_distance_batch(distance_matrix, index_lists):
    """
    Computes the Chamfer L1 distances for a batch of permutations.

    Args:
        distance_matrix (ndarray): A 2D numpy array representing the pairwise distance matrix.
        index_lists (ndarray): A 2D numpy array where each row represents a permutation of indices.

    Returns:
        ndarray: A 1D array of Chamfer L1 distances for each permutation.
    """
    len_combined = distance_matrix.shape[0]
    len_pattern = index_lists.shape[1] // 2

    # Create a 3D index array for pattern to control distances
    idx1 = np.arange(index_lists.shape[0])[:, None, None]
    idx2 = index_lists[:, :len_pattern][:, :, None]
    idx3 = index_lists[:, len_pattern:][:, None, :]
    
    distances_1_to_2 = np.min(distance_matrix[idx1, idx2, idx3], axis=2)
    distances_2_to_1 = np.min(distance_matrix[idx1, idx3, idx2], axis=2)
    
    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

def permutation_test_batch(patterns, controls, n_permutations=9999):
    combined = np.concatenate((patterns, controls), axis=1)
    num_patterns = patterns.shape[1]
    len_combined = combined.shape[1]

    observed_statistics = []
    for pattern, control in zip(patterns, controls):
        combined_pc = np.concatenate((pattern, control))
        distance_matrix = squareform(pdist(combined_pc, metric='cityblock'))
        observed_statistic = chamfer_L1_distance(distance_matrix, np.arange(len_combined))
        observed_statistics.append(observed_statistic)
    
    observed_statistics = np.array(observed_statistics)

    index_lists = np.array([np.random.permutation(len_combined) for _ in range(n_permutations)])
    chamfer_distances = []
    for combined_pc in combined:
        distance_matrix = squareform(pdist(combined_pc, metric='cityblock'))
        chamfer_distances.append(chamfer_L1_distance_batch(distance_matrix, index_lists))

    chamfer_distances = np.array(chamfer_distances)

    eps = np.finfo(observed_statistics.dtype).eps * 100
    gamma = np.abs(eps * observed_statistics)
    cmps_greater = chamfer_distances >= observed_statistics[:, None] - gamma[:, None]

    pvalues_greater = (cmps_greater.sum(axis=1) + 1) / (n_permutations + 1)
    return pvalues_greater


In [4]:
def compute_power_permutation_batch(params):
    logging.info('Starting compute_power_permutation with params: %s', params)
    try:
        strength, count, sample = params
        significant_count = 0
        bonferroni_count = 0
        if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
            return (f'{strength}_{count}_{sample}', -1)
        
        patterns = []
        controls = []

        for i in range(1000):
            pattern = subset_power_analysis(adata_test, mixed_patterns=True, pattern_strength=strength, rna_count=count, sample_size=sample, random_seed=False)
            control = subset_power_analysis(adata_test, pattern='random', mixed_patterns=False, rna_count=count, sample_size=sample, random_seed=False)
            patterns.append(pattern.obsm["latent"])
            controls.append(control.obsm["latent"])

        patterns = np.array(patterns)
        controls = np.array(controls)

        pvalues = permutation_test_batch(patterns, controls, n_permutations=9999)

        critical_value = 0.05
        adjusted_critical_value = critical_value / 5000
        significant_count = np.sum(pvalues < critical_value)
        bonferroni_count = np.sum(pvalues < adjusted_critical_value)

        result = (f'{strength}_{count}_{sample}', significant_count / 1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count / 1000)
        
        with open('/media/gambino/students_workdir/nynke/blurry/results.txt', 'a') as f:
            f.write(f'{strength}\t{count}\t{sample}\t{significant_count / 1000}\t{bonferroni_count / 1000}\n')
        
        logging.info('Finished compute_power_permutation with params: %s', params)
        return result, bonferroni_result

    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e


In [2]:
adata_split_cellID = sc.read_h5ad("/media/gambino/students_workdir/nynke/new_model_with_cell_id_left_out_custom_nynke_panel_simulated_embeddings_adata.h5ad")
adata_split_cellID = initialize_adata(adata_split_cellID)
adata_test = adata_split_cellID[adata_split_cellID.obs['cell_id'].isin(adata_split_cellID.uns['test_cellIDs'])]
combination = ('low', '10-30', 753)
strength, count, sample = combination

In [3]:
adata_test

View of AnnData object with n_obs × n_vars = 329349 × 15
    obs: 'pattern', 'random_or_pattern', 'n_spots', 'n_spots_interval', 'cell_id', 'genes', 'rotation', 'rotation_interval', 'blur', 'prop', 'prop_interval', 'corresponding_dapis', 'train_or_val', 'original_image_paths', 'pattern_strength', 'rna_count'
    uns: 'test_cellIDs', 'train_cellIDs'
    obsm: 'latent'

In [1]:
strengths = ['strong', 'intermediate', 'low']
counts = ['0-10', '10-30', '30-60', '60-100','100+']
#samples = [5, 9, 15, 27, 46, 81, 142, 247, 432, 753, 1315, 2297, 4009, 7000]
samples = [1315, 2297, 4009] # , 7000]

# Create a list of all combinations of strength, count, and sample
combinations = [(strength, count, sample) for strength in strengths for count in counts for sample in samples]
print(combinations)
combinations.append(('low', '0-10', 753))
print(combinations)

[('strong', '0-10', 1315), ('strong', '0-10', 2297), ('strong', '0-10', 4009), ('strong', '10-30', 1315), ('strong', '10-30', 2297), ('strong', '10-30', 4009), ('strong', '30-60', 1315), ('strong', '30-60', 2297), ('strong', '30-60', 4009), ('strong', '60-100', 1315), ('strong', '60-100', 2297), ('strong', '60-100', 4009), ('strong', '100+', 1315), ('strong', '100+', 2297), ('strong', '100+', 4009), ('intermediate', '0-10', 1315), ('intermediate', '0-10', 2297), ('intermediate', '0-10', 4009), ('intermediate', '10-30', 1315), ('intermediate', '10-30', 2297), ('intermediate', '10-30', 4009), ('intermediate', '30-60', 1315), ('intermediate', '30-60', 2297), ('intermediate', '30-60', 4009), ('intermediate', '60-100', 1315), ('intermediate', '60-100', 2297), ('intermediate', '60-100', 4009), ('intermediate', '100+', 1315), ('intermediate', '100+', 2297), ('intermediate', '100+', 4009), ('low', '0-10', 1315), ('low', '0-10', 2297), ('low', '0-10', 4009), ('low', '10-30', 1315), ('low', '10-

In [5]:
import concurrent.futures
import numpy as np
from multiprocessing import Pool, Manager
import logging
import multiprocessing as mp
import random
from scipy.spatial.distance import pdist, squareform
from math import comb
from dataprep import *
import pickle
import scanpy as sc


def compute_power_permutation(params, adata_test):
    logging.info('Starting compute_power_permutation with params: %s', params)
    try:
        strength, count, sample = params
        significant_count = 0
        bonferroni_count = 0
        if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
            return (f'{strength}_{count}_{sample}', -1)
        
        # Generate 1000 samples of pattern and control
        patterns = np.array([subset_power_analysis(adata_test, mixed_patterns=True, pattern_strength=strength, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(1000)])
        controls = np.array([subset_power_analysis(adata_test, pattern='random', mixed_patterns=False, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(1000)])

        # Count max number of permutations with Combination rule nCr, where r is the pattern size
        if sample < 15:
            total_permutations = comb(sample*2, sample) # built in implementation of nCr rule.
            # Adjust n_permutations if it's larger than total_permutations
            if n_permutations > total_permutations:
                exact_test = True
                n_permutations = int(total_permutations)
            else:
                n_permutations = 9999
                exact_test = False
        else:
            # If num_pattern is 15, the total combinations are 1.5e8, which already is much larger than 9999. So we skip calculating the factorials for 15+ to save compute time. 
            n_permutations = 9999
            exact_test = False

        pvalues = permutation_test_batch(patterns, controls, n_permutations=n_permutations, exact_test=exact_test)

        critical_value = 0.05
        adjusted_critical_value = critical_value / 5000
        significant_count = np.sum(pvalues < critical_value)
        bonferroni_count = np.sum(pvalues < adjusted_critical_value)

        result = (f'{strength}_{count}_{sample}', significant_count / 1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count / 1000)
        
        with open('/media/gambino/students_workdir/nynke/blurry/results_permutation.txt', 'a') as f:
            f.write(f'{strength}\t{count}\t{sample}\t{significant_count / 1000}\t{bonferroni_count / 1000}\n')
        
        logging.info('Finished compute_power_permutation with params: %s', params)
        return result, bonferroni_result
    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e


def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    distances_1_to_2 = np.min(distance_matrix[np.ix_(index_list[:len_pattern], index_list[len_pattern:])], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)

def chamfer_L1_distance_batch(distance_matrices, index_list):
    """
    Computes the Chamfer L1 distances for a batch of distance matrices
    Args:
        distance_matrix (ndarray): A 3D numpy array representing the pairwise distance matrix for many samples.
        index_lists (ndarray): A 1D numpy array represents a permutation of indices.
    Returns:
        ndarray: A 1D array of Chamfer L1 distances for all samples.
    """
    len_pattern = index_list.shape[0] // 2
    n_power_iterations=distance_matrices.shape[0]
    # Create a 3D index array for pattern to control distances
    idx2 = index_list[:len_pattern]
    idx3 = index_list[len_pattern:]
    distances_1_to_2 = np.min(distance_matrices[np.ix_(np.arange(n_power_iterations), idx2, idx3)], axis=2)
    distances_2_to_1 = np.min(distance_matrices[np.ix_(np.arange(n_power_iterations), idx3, idx2)], axis=2)
    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

def permutation_test_batch(patterns, controls, n_permutations=9999, exact_test=False):
    combined = np.concatenate((patterns, controls), axis=1)
    len_combined = combined.shape[1]
    indices_list = np.arange(len_combined)

    distance_matrices = np.array([squareform(pdist(combined[i], metric='cityblock')) for i in range(1000)])

    # Compute observed statistics for all samples
    observed_statistics = np.array([chamfer_L1_distance_batch(distance_matrices, indices_list)])
    
    # Generate index permutations
    index_lists = np.array([np.random.permutation(len_combined) for _ in range(n_permutations)])

    # Compute Chamfer distances for all permutations and samples using multithreading
    with concurrent.futures.ThreadPoolExecutor() as executor:
        chamfer_distances = np.array(list(executor.map(lambda index_list: chamfer_L1_distance_batch(distance_matrices, index_list), index_lists)))

    # Compute Chamfer distances for all permutations and samples
    #chamfer_distances = np.array([chamfer_L1_distance_batch(distance_matrices, index_list) for index_list in index_lists])


    # Calculate p-values (one-sided test cause only interested in larger distances than H0)
    eps = np.finfo(observed_statistics.dtype).eps * 100
    gamma = np.abs(eps * observed_statistics)
    cmps_greater = chamfer_distances >= observed_statistics[:, None] - gamma[:, None]
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum(axis=1) + adjustment) / (n_permutations + adjustment)
    return pvalues_greater

if __name__ == '__main__':
    # Set up logging
    logging.basicConfig(filename='permute.log', filemode='a', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)
    logging.info('Loading adata')

    adata_split_cellID = sc.read_h5ad("/media/gambino/students_workdir/nynke/new_model_with_cell_id_left_out_custom_nynke_panel_simulated_embeddings_adata.h5ad")
    adata_split_cellID = initialize_adata(adata_split_cellID)
    adata_test = adata_split_cellID[adata_split_cellID.obs['cell_id'].isin(adata_split_cellID.uns['test_cellIDs'])]
    params=('low', '10-30', 753)
    logging.info('Adata loaded')
    result, bonferroni_result = compute_power_permutation(params, adata_test)
    logging.info('Result: %s', result)
    print("hooray we're done!")
    '''

    strengths = ['strong', 'intermediate', 'low']
    counts = adata_test.obs['rna_count'].unique()
    #samples = [5, 9, 15, 27, 46, 81, 142, 247, 432, 753, 1315, 2297, 4009, 7000]
    samples = [1315, 2297, 4009, 7000]

    # Create a list of all combinations of strength, count, and sample
    combinations = [(strength, count, sample) for strength in strengths for count in counts for sample in samples]

    # Set the start method to 'spawn'
    mp.set_start_method('spawn', force=True)


    # Create a multiprocessing pool and compute the power for each combination
    with Manager() as manager:
        shared_adata_test = manager.list([adata_test])
        with Pool(20) as p:
            results, bonferroni_result = p.starmap(compute_power_permutation, [(params, shared_adata_test) for params in combinations])

    # in case want to do single processing:
    # results = [compute_power_permutation(combination) for combination in combinations]

    # Convert the results to a dictionary
    power_results = dict(results)
    bonferroni_power_results = dict(bonferroni_result)

    path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_uncorrected.pkl"
    bonferroni_path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_bonferroni.pkl"

    # Open the file in write-binary mode and dump the object
    with open(path, 'wb') as f:
        pickle.dump(power_results, f)

    # Open the file in write-binary mode and dump the object
    with open(bonferroni_path, 'wb') as f:
        pickle.dump(bonferroni_power_results, f)'''


In [None]:
if __name__ == '__main__':
    # Set up logging
    logging.basicConfig(filename='permute.log', filemode='a', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)
    logging.info('Loading adata')

    adata_split_cellID = sc.read_h5ad("/media/gambino/students_workdir/nynke/new_model_with_cell_id_left_out_custom_nynke_panel_simulated_embeddings_adata.h5ad")
    adata_split_cellID = initialize_adata(adata_split_cellID)
    adata_test = adata_split_cellID[adata_split_cellID.obs['cell_id'].isin(adata_split_cellID.uns['test_cellIDs'])]

    strengths = ['strong', 'intermediate', 'low']
    counts = adata_test.obs['rna_count'].unique()
    #samples = [5, 9, 15, 27, 46, 81, 142, 247, 432, 753, 1315, 2297, 4009, 7000]
    samples = [1315, 2297, 4009, 7000]

    # Create a list of all combinations of strength, count, and sample
    combinations = [(strength, count, sample) for strength in strengths for count in counts for sample in samples]

    # Set the start method to 'spawn'
    mp.set_start_method('spawn', force=True)

    # Create a multiprocessing pool and compute the power for each combination
    with Pool(20) as p:
        results, bonferroni_result = p.map(compute_power_permutation, combinations)

    # in case want to do single processing:
    # results = [compute_power_permutation(combination) for combination in combinations]

    # Convert the results to a dictionary
    power_results = dict(results)
    bonferroni_power_results = dict(bonferroni_result)

    path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_uncorrected.pkl"
    bonferroni_path = "temp_objects/power_analysis_permutationLatent_to7000_logscale_bonferroni.pkl"

    # Open the file in write-binary mode and dump the object
    with open(path, 'wb') as f:
        pickle.dump(power_results, f)

    # Open the file in write-binary mode and dump the object
    with open(bonferroni_path, 'wb') as f:
        pickle.dump(bonferroni_power_results, f)


In [8]:
patterns = np.array([subset_power_analysis(adata_test, mixed_patterns=True, pattern_strength=strength, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(10)])
controls = np.array([subset_power_analysis(adata_test, pattern='random', mixed_patterns=False, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(10)])
        

In [6]:
patterns = []
controls = []

for i in range(10):
    pattern = subset_power_analysis(adata_test, mixed_patterns=True, pattern_strength=strength, rna_count=count, sample_size=sample, random_seed=False)
    control = subset_power_analysis(adata_test, pattern='random', mixed_patterns=False, rna_count=count, sample_size=sample, random_seed=False)
    patterns.append(pattern.obsm["latent"])
    controls.append(control.obsm["latent"])

patterns = np.array(patterns)
controls = np.array(controls)

In [30]:
patterns.shape, controls.shape

((10, 753, 15), (10, 753, 15))

In [9]:
combined = np.concatenate((patterns, controls), axis=1)
len_combined = combined.shape[1]
indices_list = np.arange(len_combined)

((10, 1506, 15), (1506,))

In [10]:
distance_matrices = np.array([squareform(pdist(combined[i], metric='cityblock')) for i in range(10)])
distance_matrices.shape

(10, 1506, 1506)

In [13]:
observed_statistics = np.array([chamfer_L1_distance(distance_matrix, indices_list) for distance_matrix in distance_matrices])
observed_statistics

array([25.4700732 , 25.48903903, 25.53854045, 25.51355722, 25.31809289,
       25.44363946, 25.56295549, 25.67546666, 25.58864391, 25.60535183])

In [14]:
index_lists = np.array([np.random.permutation(len_combined) for _ in range(9999)])
index_lists.shape

(9999, 1506)

In [19]:
distance_matrices[0].shape

(1506, 1506)

In [20]:
len_combined = distance_matrices[0].shape[0]
len_pattern = index_lists.shape[1] // 2
len_pattern

753

In [21]:
idx1 = np.arange(index_lists.shape[0])[:, None, None]
idx2 = index_lists[:, :len_pattern][:, :, None]
idx3 = index_lists[:, len_pattern:][:, None, :]
idx1.shape, idx2.shape, idx3.shape

((9999, 1, 1), (9999, 753, 1), (9999, 1, 753))

In [22]:
distances_1_to_2 = np.min(distance_matrices[0][idx1, idx2, idx3], axis=2)
distances_1_to_2.shape

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed

In [24]:
len_pattern = index_lists.shape[1] // 2
indices_pattern = index_lists[:, :len_pattern]
indices_control = index_lists[:, len_pattern:]

In [25]:

indices_pattern.shape, indices_control.shape

((9999, 753), (9999, 753))

In [27]:
distance_matrix = distance_matrices[0]
distance_matrix.shape, distance_matrices.shape

((1506, 1506), (10, 1506, 1506))

In [29]:
distance_matrix[np.newaxis, :, :][:, indices_pattern, :][:, :, indices_control].shape

MemoryError: Unable to allocate 825. TiB for an array with shape (9999, 753, 1, 9999, 1506) and data type float64

ChatGPT suggestions for the batch calculation of chamfer distance

In [None]:
def chamfer_L1_distance_batch(distance_matrix, index_lists):
    """
    Computes the Chamfer L1 distances for a batch of permutations.

    Args:
        distance_matrix (ndarray): A 2D numpy array representing the pairwise distance matrix.
        index_lists (ndarray): A 2D numpy array where each row represents a permutation of indices.

    Returns:
        ndarray: A 1D array of Chamfer L1 distances for each permutation.
    """
    len_pattern = index_lists.shape[1] // 2
    indices_pattern = index_lists[:, :len_pattern]
    indices_control = index_lists[:, len_pattern:]

    # Broadcasting the distance matrix to match the dimensions required for batch processing
    distances_1_to_2 = np.min(distance_matrix[np.newaxis, :, :][:, indices_pattern, :][:, :, indices_control], axis=2)
    distances_2_to_1 = np.min(distance_matrix[np.newaxis, :, :][:, indices_control, :][:, :, indices_pattern], axis=2)

    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

In [None]:
def chamfer_L1_distance_batch(distance_matrix, index_lists):
    """
    Computes the Chamfer L1 distances for a batch of permutations.

    Args:
        distance_matrix (ndarray): A 2D numpy array representing the pairwise distance matrix.
        index_lists (ndarray): A 2D numpy array where each row represents a permutation of indices.

    Returns:
        ndarray: A 1D array of Chamfer L1 distances for each permutation.
    """
    len_combined = distance_matrix.shape[0]
    len_pattern = index_lists.shape[1] // 2

    # Create a 3D index array for pattern to control distances
    idx1 = np.arange(index_lists.shape[0])[:, None, None]
    idx2 = index_lists[:, :len_pattern][:, :, None]
    idx3 = index_lists[:, len_pattern:][:, None, :]
    
    distances_1_to_2 = np.min(distance_matrix[idx1, idx2, idx3], axis=2)
    distances_2_to_1 = np.min(distance_matrix[idx1, idx3, idx2], axis=2)
    
    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)


In [None]:
def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    distances_1_to_2 = np.min(distance_matrix[np.ix_(index_list[:len_pattern], index_list[len_pattern:])], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)

In [38]:
distance_matrices = np.zeros((10, 100, 100))
x = np.arange(100)
n_permutations = 9
len_pattern = 50
rng = np.random.default_rng(1234)
indices_lists = rng.permuted(np.tile(x, n_permutations).reshape(n_permutations, x.size), axis=1)
for index_list in indices_lists:
    print(distance_matrices[np.ix_(np.arange(10), index_list[:len_pattern], index_list[len_pattern:])].shape)

(10, 50, 50)
(10, 50, 50)
(10, 50, 50)
(10, 50, 50)
(10, 50, 50)
(10, 50, 50)
(10, 50, 50)
(10, 50, 50)
(10, 50, 50)


In [51]:
indices_lists[0].shape[0]

100

In [52]:
def chamfer_L1_distance_batch(distance_matrices, index_list):
    """
    Computes the Chamfer L1 distances for a batch of distance matrices
    Args:
        distance_matrix (ndarray): A 3D numpy array representing the pairwise distance matrix for many samples.
        index_lists (ndarray): A 1D numpy array represents a permutation of indices.
    Returns:
        ndarray: A 1D array of Chamfer L1 distances for all samples.
    """
    len_pattern = index_list.shape[0] // 2
    # Create a 3D index array for pattern to control distances
    idx2 = index_list[:len_pattern]
    idx3 = index_list[len_pattern:]
    distances_1_to_2 = np.min(distance_matrices[np.ix_(np.arange(10), index_list[:len_pattern], index_list[len_pattern:])], axis=2)
    distances_2_to_1 = np.min(distance_matrices[np.ix_(np.arange(10), index_list[:len_pattern], index_list[len_pattern:])], axis=2)
    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)
distance_matrices = np.zeros((10, 100, 100))
x = np.arange(100)
n_permutations = 9
len_pattern = 50
rng = np.random.default_rng(1234)
indices_lists = rng.permuted(np.tile(x, n_permutations).reshape(n_permutations, x.size), axis=1)
for index_list in indices_lists:
    print(distance_matrices[np.ix_(np.arange(10), index_list[:len_pattern], index_list[len_pattern:])].shape)
    print(chamfer_L1_distance_batch(distance_matrices, index_list).shape)

(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)


In [53]:

distance_matrices = np.zeros((10, 100, 100))
x = np.arange(100)
n_permutations = 9
len_pattern = 50
rng = np.random.default_rng(1234)
indices_lists = rng.permuted(np.tile(x, n_permutations).reshape(n_permutations, x.size), axis=1)
for index_list in indices_lists:
    print(distance_matrices[np.ix_(np.arange(10), index_list[:len_pattern], index_list[len_pattern:])].shape)
    print(chamfer_L1_distance_batch(distance_matrices, index_list).shape)

(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)
(10, 50, 50)
(10,)


Newest version

In [60]:
import concurrent.futures

def compute_power_permutation_test(params):
    logging.info('Starting compute_power_permutation with params: %s', params)
    try:
        strength, count, sample = params
        significant_count = 0
        bonferroni_count = 0
        if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
            return (f'{strength}_{count}_{sample}', -1)
        
        # Generate 1000 samples of pattern and control
        patterns = np.array([subset_power_analysis(adata_test, mixed_patterns=True, pattern_strength=strength, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(1000)])
        controls = np.array([subset_power_analysis(adata_test, pattern='random', mixed_patterns=False, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(1000)])
        logging.info('Generated the 1000 samples of pattern and control')

        # Count max number of permutations with Combination rule nCr, where r is the pattern size
        if sample < 15:
            total_permutations = comb(sample*2, sample) # built in implementation of nCr rule.
            # Adjust n_permutations if it's larger than total_permutations
            if n_permutations > total_permutations:
                exact_test = True
                n_permutations = int(total_permutations)
            else:
                n_permutations = 9999
                exact_test = False
        else:
            # If num_pattern is 15, the total combinations are 1.5e8, which already is much larger than 9999. So we skip calculating the factorials for 15+ to save compute time. 
            n_permutations = 9999
            exact_test = False

        pvalues = permutation_test_batch(patterns, controls, n_permutations=n_permutations, exact_test=exact_test)

        critical_value = 0.05
        adjusted_critical_value = critical_value / 5000
        significant_count = np.sum(pvalues < critical_value)
        bonferroni_count = np.sum(pvalues < adjusted_critical_value)

        result = (f'{strength}_{count}_{sample}', significant_count / 1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count / 1000)
        
        #with open('/media/gambino/students_workdir/nynke/blurry/results.txt', 'a') as f:
        #    f.write(f'{strength}\t{count}\t{sample}\t{significant_count / 1000}\t{bonferroni_count / 1000}\n')
        
        logging.info('Finished compute_power_permutation with params: %s', params)
        return result, bonferroni_result

    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e


def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    distances_1_to_2 = np.min(distance_matrix[np.ix_(index_list[:len_pattern], index_list[len_pattern:])], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)

def chamfer_L1_distance_batch(distance_matrices, index_list):
    """
    Computes the Chamfer L1 distances for a batch of distance matrices
    Args:
        distance_matrix (ndarray): A 3D numpy array representing the pairwise distance matrix for many samples.
        index_lists (ndarray): A 1D numpy array represents a permutation of indices.
    Returns:
        ndarray: A 1D array of Chamfer L1 distances for all samples.
    """
    len_pattern = index_list.shape[0] // 2
    n_power_iterations=distance_matrices.shape[0]
    # Create a 3D index array for pattern to control distances
    idx2 = index_list[:len_pattern]
    idx3 = index_list[len_pattern:]
    distances_1_to_2 = np.min(distance_matrices[np.ix_(np.arange(n_power_iterations), idx2, idx3)], axis=2)
    distances_2_to_1 = np.min(distance_matrices[np.ix_(np.arange(n_power_iterations), idx3, idx2)], axis=2)
    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

def permutation_test_batch(patterns, controls, n_permutations=9999, exact_test=False):
    combined = np.concatenate((patterns, controls), axis=1)
    len_combined = combined.shape[1]
    indices_list = np.arange(len_combined)

    logging.info('Started permutation_test_batch')

    distance_matrices = np.array([squareform(pdist(combined[i], metric='cityblock')) for i in range(1000)])
    logging.info('computed all distance matrices')

    # Compute observed statistics for all samples
    #vectorized_chamfer_L1_distance = np.vectorize(chamfer_L1_distance, signature='(n,m)->()')
    #observed_statistics = vectorized_chamfer_L1_distance(distance_matrices, indices_list)

    observed_statistics = np.array([chamfer_L1_distance(distance_matrix, indices_list) for distance_matrix in distance_matrices])
    logging.info('computed all observed statistics')
    
    # Generate index permutations
    index_lists = np.array([np.random.permutation(len_combined) for _ in range(n_permutations)])
    #vectorized_permutation = np.vectorize(np.random.permutation, signature='()->(n)')
    #index_lists = vectorized_permutation(len_combined, size=n_permutations)

    logging.info('Generated indices permutations')

    # Function to compute Chamfer distances for a single permutation and sample. Reuse the index permutations for all 1000 power analysis runs.
    #def compute_chamfer_distances(index_list):
    #    return chamfer_L1_distance_batch(distance_matrices, index_lists)

    # Compute Chamfer distances for all permutations and samples using multithreading
    with concurrent.futures.ThreadPoolExecutor() as executor:
        chamfer_distances = np.array(list(executor.map(lambda index_list: chamfer_L1_distance_batch(distance_matrices, index_list), index_lists)))
    

    # Compute Chamfer distances for all permutations and samples
    #chamfer_distances = np.array([chamfer_L1_distance_batch(distance_matrices, index_list) for index_list in index_lists])
    logging.info('Computed all chamfer distances')

    # Calculate p-values (one-sided test cause only interested in larger distances than H0)
    eps = np.finfo(observed_statistics.dtype).eps * 100
    gamma = np.abs(eps * observed_statistics)
    cmps_greater = chamfer_distances >= observed_statistics[:, None] - gamma[:, None]
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum(axis=1) + adjustment) / (n_permutations + adjustment)
    return pvalues_greater

In [75]:
pattern_adata = subset_power_analysis(adata_test, mixed_patterns = True, pattern_strength= strength, rna_count = count, sample_size = sample, random_seed=False)
control_adata = subset_power_analysis(adata_test, pattern = 'random', mixed_patterns = False, rna_count = count, sample_size = sample, random_seed=False)

In [95]:
def compute_power_permutation(params):
    try:
        strength, count, sample = params
        significant_count = 0
        bonferroni_count = 0
        # Given that random patterns have only 1442 cells simulated, we don't calculate the power for these above 1400 so that we don't need to sample with replacement
        if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
            return (f'{strength}_{count}_{sample}', -1)
    
        for i in range(1000):
            # sample new gene. No random seed so that every time a different "gene" is sampled.
            pattern = subset_power_analysis(adata_test, mixed_patterns = True, pattern_strength= strength, rna_count = count, sample_size = sample, random_seed=False)
            control = subset_power_analysis(adata_test, pattern = 'random', mixed_patterns = False, rna_count = count, sample_size = sample, random_seed=False)

            # Calculate null distribution and pvalue
            pvalue = test_permutation(pattern.obsm["latent"], control.obsm["latent"], n_permutations=9999, return_distances = False)
            
            critical_value = 0.05
            adjusted_critical_value = critical_value / 5000
            if pvalue < critical_value:
                significant_count += 1
            if pvalue < adjusted_critical_value:
                bonferroni_count += 1
        result = (f'{strength}_{count}_{sample}', significant_count/1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count/1000)

        return result, bonferroni_result
    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e

In [96]:
combination = ('low', '10-30', 753)
result, bonferroni_result = compute_power_permutation(combination)

In [97]:
print(result)
print(bonferroni_result)

('low_10-30_753', 0.867)
('low_10-30_753', 0.0)


In [34]:
import logging
logging.basicConfig(filename='test.log', filemode='a', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)

In [61]:
result, bonferroni_result = compute_power_permutation_test(combination)
print(result)
print(bonferroni_result)

KeyboardInterrupt: 

In [107]:
index_lists.shape

(9999, 1506)

In [113]:
len_pattern = index_lists.shape[1] // 2
index_lists[:, :len_pattern].shape

(9999, 1506)

In [136]:
indices_pattern = index_lists[:, :len_pattern]
indices_control = index_lists[:, len_pattern:]

In [137]:
indices_pattern.shape

(9999, 753)

In [140]:
len_pattern = index_lists.shape[1] // 2
indices_pattern = index_lists[:, :len_pattern]
indices_control = index_lists[:, len_pattern:]

index_list = index_lists[0]
np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1).shape

(753,)

In [157]:
indices_pattern.shape

(9999, 753)

In [144]:
distance_matrix_3d = distance_matrix[:, np.newaxis]

In [156]:
distance_matrix_3d.shape

(1506, 1, 1506)

In [150]:
subset_distance_matrices = distance_matrix_3d[indices_control, :, indices_pattern]
subset_distance_matrices.shape

(9999, 753, 1)

In [152]:
np.argmin(distance_matrix_3d[indices_control, :, indices_pattern], axis=2).shape
I want my distance_matrix_3D to be of the dimensions (1506, 9999, 1506) instead of (1506,1,1506)

(9999, 753)

In [None]:
# Create a 3D distance matrix using broadcasting
distance_matrix_3d = distance_matrix[:, np.newaxis]

# Subset the 3D distance matrix using advanced indexing
subset_distance_matrices = distance_matrix_3d[indices_control, :, indices_pattern]

In [112]:
distance_matrix.shape

(1506, 1506)

In [164]:
index_list.shape

(1506,)

In [159]:
#compute my chamfer distance for all permutations in one go. Idea is to broadcast the distance matrix to the shape (1506, 9999, 1506) and then subset it & get the nearest neighbours
def chamfer_L1_distance_batch(distance_matrix, index_lists):
    len_pattern = index_lists.shape[1] // 2
    indices_pattern = index_lists[:, :len_pattern]
    indices_control = index_lists[:, len_pattern:]

    distances_1_to_2 = np.min(distance_matrix[:, indices_pattern, indices_control], axis=1)
    distances_2_to_1 = np.min(distance_matrix[:, indices_control, indices_pattern], axis=1)

    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

In [165]:
indices_pattern = index_list[:, :len_pattern]
indices_control = index_list[:, len_pattern:]
distance_matrix[:, indices_pattern, indices_control]

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

In [None]:
len_pattern = index_lists.shape[1] // 2
indices_pattern = index_lists[:, :len_pattern]
indices_control = index_lists[:, len_pattern:]

distances_1_to_2 = np.min(distance_matrix[:, indices_pattern, indices_control], axis=2)
distances_2_to_1 = np.min(distance_matrix[:, indices_control, indices_pattern], axis=2)

return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)

In [32]:
import concurrent.futures
def compute_power_permutation_test(params):
    logging.info('Starting compute_power_permutation with params: %s', params)
    try:
        strength, count, sample = params
        significant_count = 0
        bonferroni_count = 0
        if (count == '0-10' and sample > 1400) or (count == '10-30' and sample > 5100):
            return (f'{strength}_{count}_{sample}', -1)

        # Generate 1000 samples of pattern and control
        patterns = np.array([subset_power_analysis(adata_test, mixed_patterns=True, pattern_strength=strength, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(1000)])
        controls = np.array([subset_power_analysis(adata_test, pattern='random', mixed_patterns=False, rna_count=count, sample_size=sample, random_seed=False).obsm["latent"] for _ in range(1000)])
        logging.info('Generated the 1000 samples of pattern and control')
        # Count max number of permutations with Combination rule nCr, where r is the pattern size
        if sample < 15:
            total_permutations = comb(sample*2, sample) # built in implementation of nCr rule.
            # Adjust n_permutations if it's larger than total_permutations
            if n_permutations > total_permutations:
                exact_test = True
                n_permutations = int(total_permutations)
            else:
                n_permutations = 9999
                exact_test = False
        else:
            # If num_pattern is 15, the total combinations are 1.5e8, which already is much larger than 9999. So we skip calculating the factorials for 15+ to save compute time. 
            n_permutations = 9999
            exact_test = False

        pvalues = permutation_test_batch(patterns, controls, n_permutations=n_permutations, exact_test=exact_test)

        critical_value = 0.05
        adjusted_critical_value = critical_value / 5000
        significant_count = np.sum(pvalues < critical_value)
        bonferroni_count = np.sum(pvalues < adjusted_critical_value)

        result = (f'{strength}_{count}_{sample}', significant_count / 1000)
        bonferroni_result = (f'{strength}_{count}_{sample}', bonferroni_count / 1000)
        
        #with open('/media/gambino/students_workdir/nynke/blurry/results.txt', 'a') as f:
        #    f.write(f'{strength}\t{count}\t{sample}\t{significant_count / 1000}\t{bonferroni_count / 1000}\n')
        
        logging.info('Finished compute_power_permutation with params: %s', params)
        return result, bonferroni_result

    except ValueError as e:
        if str(e) == "Cannot take a larger sample than population when 'replace=False'":
            print(f"Error for parameters: {params}")
            return None
        else:
            raise e


def chamfer_L1_distance_batch(distance_matrix, index_lists):
    len_pattern = index_lists.shape[1] // 2
    indices_pattern = index_lists[:, :len_pattern]
    indices_control = index_lists[:, len_pattern:]

    distances_1_to_2 = np.min(distance_matrix[:, indices_pattern, indices_control], axis=1)
    distances_2_to_1 = np.min(distance_matrix[:, indices_control, indices_pattern], axis=1)

    return np.mean(distances_1_to_2, axis=1) + np.mean(distances_2_to_1, axis=1)


def permutation_test_batch(patterns, controls, n_permutations=9999, exact_test=False):
    combined = np.concatenate((patterns, controls), axis=1)
    len_combined = combined.shape[1]
    indices_list = np.arange(len_combined)

    logging.info('Started permutation_test_batch')

    # Compute distance matrices for all combined samples
    distance_matrices = np.array([squareform(pdist(combined[i], metric='cityblock')) for i in range(1000)])
    logging.info('computed all distance matrices')

    # Compute observed statistics for all samples
    observed_statistics = np.array([chamfer_L1_distance(distance_matrix, indices_list) for distance_matrix in distance_matrices])
    logging.info('computed all observed statistics')
    
    # Generate index permutations
    index_lists = np.array([np.random.permutation(len_combined) for _ in range(n_permutations)])
    logging.info('Generated indices permutations')

    # Function to compute Chamfer distances for a single permutation and sample. Reuse the index permutations for all 1000 power analysis runs.
    def compute_chamfer_distances(distance_matrix):
        return chamfer_L1_distance_batch(distance_matrix, index_lists)

    # Compute Chamfer distances for all permutations and samples using multithreading
    with concurrent.futures.ThreadPoolExecutor() as executor:
        chamfer_distances = np.array(list(executor.map(compute_chamfer_distances, distance_matrices)))
    logging.info('Computed all chamfer distances')

    # Compute Chamfer distances for all permutations and samples
    #chamfer_distances = np.array([chamfer_L1_distance_batch(distance_matrix, index_lists) for distance_matrix in distance_matrices])

    # Calculate p-values (one-sided test cause only interested in larger distances than H0)
    eps = np.finfo(observed_statistics.dtype).eps * 100
    gamma = np.abs(eps * observed_statistics)
    cmps_greater = chamfer_distances >= observed_statistics[:, None] - gamma[:, None]
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum(axis=1) + adjustment) / (n_permutations + adjustment)
    return pvalues_greater

In [76]:
pattern = pattern_adata.obsm['latent']
control = control_adata.obsm['latent']

In [77]:
combined = np.concatenate([pattern, control])
len_combined = len(combined)

In [88]:
distance_matrix = squareform(pdist(combined, metric='cityblock'))
distance_matrix.shape

(1506, 1506)

In [79]:
n_permutations = 9999

In [89]:
index_lists = np.apply_along_axis(np.random.permutation, 1, np.tile(list(range(len_combined)), (n_permutations, 1)))

In [85]:
chamfer_distances = np.array([chamfer_L1_distance_test(distance_matrix, index_list) for index_list in index_lists])

In [None]:
chamfer_distances = np.array([chamfer_L1_distance_test(distance_matrix, index_list) for index_list in index_lists])

In [92]:
test_permutation(pattern, control, n_permutations = 9999, return_distances = False)

0.0008

In [93]:
permutation_test(pattern, control, n_permutations = 9999, return_distances = False)

1.0

In [41]:
np.random.permutation(1500)

array([ 188, 1319,  170, ...,  550,  136, 1134])

In [91]:
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.special import comb
from concurrent.futures import ThreadPoolExecutor, as_completed

def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    distances_1_to_2 = np.min(distance_matrix[np.ix_(index_list[:len_pattern], index_list[len_pattern:])], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)

def test_permutation(pattern, control, n_permutations: int = 9999, return_distances: bool = False):
    combined = np.concatenate([pattern, control])
    distance_matrix = squareform(pdist(combined, metric='cityblock'))
    len_combined = len(combined)
    num_pattern = len(pattern)
    observed_statistic = chamfer_L1_distance(distance_matrix, list(range(len_combined)))

    if num_pattern < 15:
        total_permutations = comb(len_combined, num_pattern)
        if n_permutations > total_permutations:
            exact_test = True
            n_permutations = int(total_permutations)
        else:
            exact_test = False
    else:
        exact_test = False

    index_lists = np.apply_along_axis(np.random.permutation, 1, np.tile(list(range(len_combined)), (n_permutations, 1)))

    chamfer_distances = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(chamfer_L1_distance, distance_matrix, index_list) for index_list in index_lists]
        for future in as_completed(futures):
            chamfer_distances.append(future.result())

    chamfer_distances = np.array(chamfer_distances)

    eps = (0 if not np.issubdtype(observed_statistic.dtype, np.inexact)
           else np.finfo(observed_statistic.dtype).eps * 100)
    gamma = np.abs(eps * observed_statistic)
    cmps_greater = chamfer_distances >= observed_statistic - gamma
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum() + adjustment) / (n_permutations + adjustment)
    p_value = pvalues_greater

    if return_distances:
        return p_value, observed_statistic, chamfer_distances
    else:
        return p_value


In [52]:
permutation_test(pattern, control, n_permutations = 9999, return_distances = False)

0.0147

In [50]:
np.apply_along_axis(np.random.permutation, 1, np.tile(index_list, (10, 1)))

array([[ 150,  678,   94, ...,  423, 1409,  203],
       [  21, 1303, 1187, ..., 1477, 1179, 1031],
       [1180, 1205, 1378, ...,  698,  328, 1013],
       ...,
       [ 521,  511, 1279, ..., 1472,  801,  105],
       [ 841,  725,  301, ...,  173, 1204,  174],
       [  23, 1043,  774, ...,  952,  147,  840]])

In [82]:
 #compute distance matrix of combined (pdist) only once (combined 2000x2000)
    #generate a list 1000x(len(array)) indices (ask chatgpt for vectorized way)
    #for all indices combos in the list (will run 1000 times):    
        # an indices of 1000 for pattern and 1000 for control
        # change chamfer_L1_distance_cpu(permuted_pattern, permuted_control) to
        # chamfer_L1_distance_cpu(pdist, indices) and inside we will filter distances 1 to 2 and distances 2 to 1
        # and continue chamfer as before

def chamfer_L1_distance(distance_matrix, index_list):
    len_pattern = len(index_list) // 2
    # subset the distance matrix with the indices of both point clouds, and get the nearest neighbor for each point from point cloud 1 in point cloud 2 and vice versa
    distances_1_to_2 = np.min(distance_matrix[index_list[:len_pattern]][index_list[len_pattern:]], axis=1)
    distances_2_to_1 = np.min(distance_matrix[np.ix_(index_list[len_pattern:], index_list[:len_pattern])], axis=1)

    # Compute the Chamfer distance
    return np.mean(distances_1_to_2) + np.mean(distances_2_to_1)


def permutation_test(pattern, control, n_permutations: int = 9999, return_distances: bool = False):
    combined = np.concatenate([pattern, control])
    distance_matrix = squareform(pdist(combined, metric='cityblock'))
    len_combined = len(combined)
    num_pattern = len(pattern)
    observed_statistic = chamfer_L1_distance(distance_matrix, list(range(len_combined)))

    # Count max number of permutations with Combination rule nCr, where r is the pattern size
    if num_pattern < 15:
        total_permutations = comb(len_combined, num_pattern) # built in implementation of nCr rule.

        # Adjust n_permutations if it's larger than total_permutations
        if n_permutations > total_permutations:
            exact_test = True
            n_permutations = int(total_permutations)
        else:
            exact_test = False
    else:
        # If num_pattern is 15, the total combinations are 1.5e8, which already is much larger than 9999. So we skip calculating the factorials for 15+ to save compute time. 
        exact_test = False

    # Permute the indices of the combined array n_permutations times and calculate the Chamfer distance for each permutation
    index_lists = np.apply_along_axis(np.random.permutation, 1, np.tile(list(range(len_combined)), (n_permutations, 1)))
    chamfer_distances = np.array([chamfer_L1_distance(distance_matrix, index_list) for index_list in index_lists])
    
    #These functions come from scipy.stats.permutation_test(). They have now been integrated in my main function in line to improve the efficiency
    eps =  (0 if not np.issubdtype(observed_statistic.dtype, np.inexact)
        else np.finfo(observed_statistic.dtype).eps*100)
    gamma = np.abs(eps * observed_statistic)
    cmps_greater = chamfer_distances >= observed_statistic - gamma
    # +1 is added to pvalues to add the observed value into the hypothetical population to make the pvalue more conservative. If it is an exact test, will use the true pvalue.
    adjustment = 0 if exact_test else 1
    pvalues_greater = (cmps_greater.sum() + adjustment) / (n_permutations + adjustment)
    # I do a 1-tailed test because I only care if the observed statistic has a larger chamfer distance than the H0 population.
    p_value = pvalues_greater
    
    if return_distances == True:
        return p_value, observed_statistic, chamfer_distances
    else:  
        return p_value

In [None]:
pvalue, observed_statistic, permuted_statistics = permutation_test_cpu(pattern.obsm["latent"], control.obsm["latent"], n_permutations=9999)