In [1]:
# Record the Python/IPython versions, the versions of the listed packages and
# the active conda environment, so the notebook's results can be reproduced.
%load_ext watermark
%watermark -v -p numpy,pandas,polars,mlxtend,omegaconf --conda

Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2

numpy    : 1.26.4
pandas   : 2.2.1
polars   : 0.20.18
mlxtend  : 0.23.1
omegaconf: 2.3.0

conda environment: torch_p11



In [2]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt

# NumPy settings
np.set_printoptions(precision=4)

# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

warnings.filterwarnings("ignore")


# auto reload imports# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt

# NumPy settings
np.set_printoptions(precision=4)

# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(500)

warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import norm
from typing import Literal


def calculate_sparse_pearson_corr(
    sparse_matrix: csr_matrix, axis: Literal[0, 1] = 1
) -> np.ndarray:
    """
    Calculate Pearson's correlation for a sparse matrix.

    The centred matrix is never materialised (subtracting a non-zero mean
    would destroy sparsity). Instead the centred cross-products are obtained
    from the identity ``Xc.T @ Xc = X.T @ X - n * outer(mean, mean)``, so the
    only large operation is a sparse-sparse product.

    Parameters
    ----------
    sparse_matrix : scipy.sparse.csr_matrix
        Input sparse matrix.
    axis : Literal[0, 1], optional
        The axis along which to compute the correlation, by default 1.
        ``axis=0`` correlates columns with each other, ``axis=1`` correlates
        rows with each other.

    Returns
    -------
    np.ndarray
        Dense correlation matrix with values in [-1, 1] and a unit diagonal.
        Vectors with zero variance get a correlation of 0 with every other
        vector.

    Raises
    ------
    ValueError
        If axis is not 0 or 1.

    Notes
    -----
    The shape of the correlation matrix depends on the input matrix and axis:
    - If axis=0: (n_features, n_features)
    - If axis=1: (n_samples, n_samples)
    """
    if axis not in (0, 1):
        raise ValueError("Axis must be 0 or 1")

    # Work in float64 (avoids integer overflow in the cross-products) and
    # orient the data so the vectors being correlated are always the columns.
    data = sparse_matrix.astype(np.float64)
    if axis == 1:
        data = data.T
    n_obs, n_vars = data.shape

    if n_obs == 0 or n_vars == 0:
        # Nothing to correlate: only the (defined-as-1) diagonal remains.
        corr = np.eye(n_vars)
        print(f"Correlation matrix shape: {corr.shape}")
        return corr

    means = np.asarray(data.mean(axis=0)).ravel()

    # Uncentred cross-products; the result is (n_vars, n_vars) and is dense
    # in the final answer anyway.
    gram = (data.T @ data).toarray()
    centered_gram = gram - n_obs * np.outer(means, means)

    # Sum of squared deviations per vector. Cancellation can leave tiny
    # negative values, so clamp at zero before taking the square root.
    raw_sq = np.diag(gram)
    sum_sq = np.clip(np.diag(centered_gram), 0.0, None)

    # Heuristic: anything below the rounding error of the subtraction above
    # is treated as exactly zero variance.
    tol = 4.0 * n_obs * np.finfo(np.float64).eps * raw_sq
    zero_var = sum_sq <= tol

    norms = np.sqrt(sum_sq)
    norms[zero_var] = 1.0  # Avoid division by zero

    # Normalising by the norms alone yields the correlation; no extra
    # division by (n - 1) is needed because it cancels out.
    corr = centered_gram / np.outer(norms, norms)

    # A constant vector is uncorrelated with everything by convention here.
    corr[zero_var, :] = 0.0
    corr[:, zero_var] = 0.0

    # Replace NaN/inf (e.g. from non-finite input) and remove rounding
    # overshoot beyond the valid range.
    corr = np.nan_to_num(corr)
    corr = np.clip(corr, -1.0, 1.0)

    # Ensure diagonal elements are exactly 1
    np.fill_diagonal(corr, 1.0)

    print(f"Correlation matrix shape: {corr.shape}")
    return corr

In [4]:
import dask.array as da
import numpy as np
from scipy.sparse import csr_matrix
from typing import Literal


def pearson_correlation_sparse(
    sparse_matrix: csr_matrix, axis: Literal[0, 1] = 1
) -> np.ndarray:
    """
    Compute the Pearson correlation matrix for a sparse matrix using Dask.

    The input is converted to a dense array with ``toarray()`` before it is
    wrapped in a Dask array, so the dense matrix has to fit in memory.
    The computation is deterministic and leaves the global NumPy / Dask
    random state untouched.

    Parameters
    ----------
    sparse_matrix : scipy.sparse.csr_matrix
        Input sparse matrix.
    axis : Literal[0, 1], optional
        The axis along which to compute the correlation. Default is 1.

    Returns
    -------
    np.ndarray
        Pearson correlation matrix with values in [-1, 1] and a unit
        diagonal. Vectors with zero variance get a correlation of 0 with
        every other vector.

    Raises
    ------
    ValueError
        If axis is not 0 or 1.

    Notes
    -----
    The shape of the input sparse_matrix is assumed to be (n_samples, n_features).
    The shape of the output correlation matrix depends on the axis:
    - If axis=0, shape is (n_features, n_features)
    - If axis=1, shape is (n_samples, n_samples)
    """
    if axis not in (0, 1):
        raise ValueError("Axis must be 0 or 1")

    # Convert sparse matrix to Dask array
    dask_data: da.Array = da.from_array(sparse_matrix.toarray(), chunks=(10_000, 3_000))

    if axis == 1:
        # Transpose so the vectors being correlated are always the columns
        dask_data = dask_data.T

    # Center the data
    mean: da.Array = dask_data.mean(axis=0)
    centered_data: da.Array = dask_data - mean

    # Compute sum of squares
    ss: da.Array = (centered_data**2).sum(axis=0)

    # Compute correlation matrix
    corr_matrix: da.Array = da.dot(centered_data.T, centered_data) / da.sqrt(
        da.outer(ss, ss)
    )

    # Convert to NumPy array
    corr_matrix_np: np.ndarray = corr_matrix.compute()

    # A zero-variance vector gives 0 / 0 = NaN above; report it as
    # uncorrelated (0) instead of propagating NaN to the caller.
    corr_matrix_np = np.where(np.isnan(corr_matrix_np), 0.0, corr_matrix_np)

    # Handle potential numerical instabilities
    corr_matrix_np = np.where(
        np.isclose(corr_matrix_np, 1, atol=1e-8), 1, corr_matrix_np
    )
    corr_matrix_np = np.where(
        np.isclose(corr_matrix_np, -1, atol=1e-8), -1, corr_matrix_np
    )

    # Ensure the range is [-1, 1] and set diagonal to 1
    corr_matrix_np = np.clip(corr_matrix_np, -1, 1)
    np.fill_diagonal(corr_matrix_np, 1)

    print(f"{corr_matrix_np.shape = }")
    return corr_matrix_np

In [9]:
import dask.array as da
import numpy as np
from scipy.sparse import csr_matrix
from typing import Literal
from tqdm import tqdm


def pearson_correlation_sparse(
    sparse_matrix: csr_matrix, axis: Literal[0, 1] = 1
) -> np.ndarray:
    """
    Compute the Pearson correlation matrix for a sparse matrix using Dask.

    The input is converted to a dense array with ``toarray()`` before it is
    wrapped in a Dask array, so the dense matrix has to fit in memory.
    The computation is deterministic and leaves the global NumPy / Dask
    random state untouched. A single-step tqdm bar reports when the Dask
    graph has finished computing.

    Parameters
    ----------
    sparse_matrix : scipy.sparse.csr_matrix
        Input sparse matrix.
    axis : Literal[0, 1], optional
        The axis along which to compute the correlation. Default is 1.

    Returns
    -------
    np.ndarray
        Pearson correlation matrix with values in [-1, 1] and a unit
        diagonal. Vectors with zero variance get a correlation of 0 with
        every other vector.

    Raises
    ------
    ValueError
        If axis is not 0 or 1.

    Notes
    -----
    The shape of the input sparse_matrix is assumed to be (n_samples, n_features).
    The shape of the output correlation matrix depends on the axis:
    - If axis=0, shape is (n_features, n_features)
    - If axis=1, shape is (n_samples, n_samples)
    """
    if axis not in (0, 1):
        raise ValueError("Axis must be 0 or 1")

    # Convert sparse matrix to Dask array
    dask_data: da.Array = da.from_array(sparse_matrix.toarray(), chunks=(10_000, 3_000))

    if axis == 1:
        # Transpose so the vectors being correlated are always the columns
        dask_data = dask_data.T

    # Center the data
    mean: da.Array = dask_data.mean(axis=0)
    centered_data: da.Array = dask_data - mean

    # Compute sum of squares
    ss: da.Array = (centered_data**2).sum(axis=0)

    # Compute correlation matrix
    corr_matrix: da.Array = da.dot(centered_data.T, centered_data) / da.sqrt(
        da.outer(ss, ss)
    )

    # The bar has a single step: it only marks start and end of the compute
    with tqdm(total=1, desc="Computing correlation matrix") as pbar:
        corr_matrix_np: np.ndarray = corr_matrix.compute()
        pbar.update(1)

    # A zero-variance vector gives 0 / 0 = NaN above; report it as
    # uncorrelated (0) instead of propagating NaN to the caller.
    corr_matrix_np = np.where(np.isnan(corr_matrix_np), 0.0, corr_matrix_np)

    # Handle potential numerical instabilities
    corr_matrix_np = np.where(
        np.isclose(corr_matrix_np, 1, atol=1e-8), 1, corr_matrix_np
    )
    corr_matrix_np = np.where(
        np.isclose(corr_matrix_np, -1, atol=1e-8), -1, corr_matrix_np
    )

    # Ensure the range is [-1, 1] and set diagonal to 1
    corr_matrix_np = np.clip(corr_matrix_np, -1, 1)
    np.fill_diagonal(corr_matrix_np, 1)

    print(f"{corr_matrix_np.shape = }")
    return corr_matrix_np

In [10]:
# Example usage
# sparse_data = csr_matrix(np.random.random((500000, 30000)))
sparse_data = csr_matrix(np.random.random((1000, 10)))

# Compute correlation between rows (axis=1): the result has one row and one
# column per row of sparse_data, i.e. shape (1000, 1000) for this input.
correlation_matrix = pearson_correlation_sparse(sparse_data, axis=1)
print("Pearson's Correlation Matrix:")
print(correlation_matrix)

# Compute correlation along rows (axis=1)
# correlation_rows = pearson_correlation_sparse(sparse_data, axis=1)
# print("Correlation matrix along rows:", correlation_rows)

Computing correlation matrix: 100%|██████████| 1/1 [00:01<00:00,  1.35s/it]


corr_matrix_np.shape = (1000, 1000)
Pearson's Correlation Matrix:
[[ 1.      0.335   0.5065 ...  0.0618 -0.4755 -0.4165]
 [ 0.335   1.      0.638  ...  0.1838  0.0465 -0.0642]
 [ 0.5065  0.638   1.     ...  0.2043 -0.3435  0.0227]
 ...
 [ 0.0618  0.1838  0.2043 ...  1.      0.1508  0.0838]
 [-0.4755  0.0465 -0.3435 ...  0.1508  1.      0.1891]
 [-0.4165 -0.0642  0.0227 ...  0.0838  0.1891  1.    ]]


In [6]:
# Sanity check: every correlation coefficient must lie within [-1, 1]
np.all(np.abs(correlation_matrix) <= 1)

True