In [None]:
# Pipeline YAML config; relative path, resolved against the notebook's working directory
CONFIG_FILE_PATH = "config/config.yml"

In [None]:
from pathlib import Path

import yaml
from pyarrow.parquet import ParquetFile
import pyarrow as pa

import pandas as pd

In [None]:
# Metadata key -> [filename prefix, value data type] used to build structured filenames
FILENAME_METADATA_MAPPING = dict(
    plate=["P-", str],
    well=["W-", str],
    tile=["T-", int],
    cycle=["C-", int],
    gene=["G-", str],
    sgrna=["SG-", str],
    channel=["CH-", str],
    dataset=["DT-", str],
)


def get_filename(data_location: dict, info_type: str, file_type: str) -> str:
    """Build a structured filename such as ``P-1_W-A1__merge_final.parquet``.

    Each known key of ``data_location`` contributes ``<prefix><value>`` (in the dict's
    iteration order), the pieces are joined with "_", and ``__<info_type>.<file_type>``
    is appended. Unknown keys are reported with a print and skipped. When no known keys
    are present the result is just ``<info_type>.<file_type>``.

    Args:
        data_location (dict): Location metadata, e.g. plate, well, tile, cycle.
        info_type (str): Kind of information stored (e.g. 'cell_features', 'sbs_reads').
        file_type (str): File extension without the dot (e.g. 'tsv', 'parquet', 'tiff').

    Returns:
        str: The structured filename.
    """
    tokens = []
    for key, value in data_location.items():
        entry = FILENAME_METADATA_MAPPING.get(key)
        if entry is None:
            print(f"Unknown metadata key: {key}")
            continue
        tokens.append(f"{entry[0]}{value}")

    base_name = f"{info_type}.{file_type}"
    if not tokens:
        return base_name
    return "_".join(tokens) + "__" + base_name


def load_parquet_subset(full_df_fp, n_rows=50000):
    """Load the first ``n_rows`` rows of a parquet file without reading the whole file.

    Record batches are streamed with ``ParquetFile.iter_batches`` until ``n_rows`` rows
    have been collected (or the file is exhausted), so only the leading part is read.

    Args:
        full_df_fp (str | pathlib.Path): Path to the parquet file.
        n_rows (int): Number of rows to read. Defaults to 50000.

    Returns:
        pd.DataFrame: The first ``n_rows`` rows (fewer if the file is shorter; an empty
            frame with the file's columns if the file has no rows).

    Raises:
        ValueError: If ``n_rows`` is less than 1.
    """
    if n_rows < 1:
        raise ValueError(f"n_rows must be at least 1, got {n_rows}")

    print(f"Reading first {n_rows:,} rows from {full_df_fp}")

    batches = []
    rows_read = 0
    # The context manager closes the file handle ParquetFile opened for the path.
    with ParquetFile(full_df_fp) as parquet_file:
        schema = parquet_file.schema_arrow
        # batch_size is only an upper bound; a batch may end early (e.g. at a row-group
        # boundary), so keep pulling batches until enough rows have been collected.
        for batch in parquet_file.iter_batches(batch_size=n_rows):
            batches.append(batch)
            rows_read += batch.num_rows
            if rows_read >= n_rows:
                break

    # A file with no rows yields no batches; fall back to an empty table with its columns.
    table = pa.Table.from_batches(batches) if batches else schema.empty_table()
    return table.slice(0, n_rows).to_pandas()

In [None]:
# Plate/well whose merged table is loaded for parameter exploration below
TEST_PLATE, TEST_WELL = 1, "A1"

In [None]:
# Read the pipeline config and resolve the analysis root directory
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)
ROOT_FP = Path(config["all"]["root_fp"])

# Read only the leading rows of the merged table for the test plate/well
# (takes ~1 minute)
merge_final_fp = ROOT_FP.joinpath(
    "merge",
    "parquets",
    get_filename({"plate": TEST_PLATE, "well": TEST_WELL}, "merge_final", "parquet"),
)
cell_data = load_parquet_subset(merge_final_fp)

display(cell_data)

In [None]:
print("First 20 columns; use to set parameters below.")
# zip with range(20) stops after the first 20 columns (or fewer if there are fewer)
for index, col in zip(range(20), cell_data.columns):
    print(index, col)

In [None]:
def perturbation_filter(
    cell_data,
    perturbation_name_col,
    perturbation_multi_col=None,
    filter_single_pert=False,
):
    """Remove cells without a perturbation assignment and optionally keep only single-gene cells.

    Args:
        cell_data (pd.DataFrame): Raw dataframe containing cell measurements.
        perturbation_name_col (str): Column containing perturbation assignments; rows where
            it is NA are removed.
        perturbation_multi_col (str, optional): Boolean column that is True for cells mapped
            to a single gene. When None, the single/multi-gene step is skipped entirely.
            Defaults to None.
        filter_single_pert (bool): If True, keep only rows where ``perturbation_multi_col``
            is True. If False, only print a warning with the number of rows where it is
            False. Defaults to False.

    Returns:
        pd.DataFrame: Filtered copy of ``cell_data``.

    Raises:
        ValueError: If ``filter_single_pert`` is True but ``perturbation_multi_col`` is None.
    """
    if filter_single_pert and perturbation_multi_col is None:
        raise ValueError("perturbation_multi_col is required when filter_single_pert=True")

    # Remove cells without perturbation assignments
    clean_cell_data = cell_data[cell_data[perturbation_name_col].notna()].copy()
    print(f"Found {len(clean_cell_data)} cells with assigned perturbations")

    if perturbation_multi_col is None:
        return clean_cell_data

    single_flag = clean_cell_data[perturbation_multi_col]
    if filter_single_pert:
        # fillna(False) keeps nullable-boolean columns with <NA> usable as a row mask
        is_single = single_flag.eq(True).fillna(False)
        clean_cell_data = clean_cell_data[is_single]
        print(f"Kept {len(clean_cell_data)} cells with single gene assignments")
    else:
        # Warn about multi-gene cells if not filtering
        multi_pert_cells = int(single_flag.eq(False).fillna(False).sum())
        if multi_pert_cells > 0:
            print(
                f"WARNING: {multi_pert_cells} cells have multiple perturbation assignments"
            )

    return clean_cell_data

In [None]:
# Column holding each cell's perturbation assignment (rows with NA here are dropped)
PERTURBATION_NAME_COL = "gene_symbol_0"
# Boolean column: True when the cell maps to a single gene
PERTURBATION_MULTI_COL = "mapped_single_gene"
FILTER_SINGLE_PERT = False

perturbation_filtered = perturbation_filter(
    cell_data,
    perturbation_name_col=PERTURBATION_NAME_COL,
    perturbation_multi_col=PERTURBATION_MULTI_COL,
    filter_single_pert=FILTER_SINGLE_PERT,
)
print(f"Unique populations: {perturbation_filtered[PERTURBATION_NAME_COL].nunique()}")
perturbation_filtered

In [None]:
# Columns before this position are metadata; from here on they are features.
# Chosen from the column listing printed above.
FEATURE_START_IDX = 17

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

def visualize_na_matrix(df, high_na_pct=10):
    """Plot a heatmap of missing values for every column that has at least one NA.

    Only columns containing NAs are drawn (NA cells colored, no color bar). Columns whose
    NA percentage exceeds ``high_na_pct`` are also listed on stdout.

    Args:
        df (pd.DataFrame): DataFrame to analyze for NA values.
        high_na_pct (float): Percentage above which a column is listed in the printed
            summary. Defaults to 10.

    Returns:
        matplotlib.figure.Figure | None: The heatmap figure if NAs are found, None otherwise.
    """
    # Compute the NA mask once and reuse it for column selection, plotting and counts
    na_mask = df.isna()
    cols_with_na = na_mask.columns[na_mask.any()].tolist()

    if not cols_with_na:
        print("No columns with NA values found in the DataFrame.")
        return None

    na_df = na_mask[cols_with_na]

    fig, ax = plt.subplots(figsize=(15, 10))
    # True (NA) values are colored; the map is binary so no color bar is needed
    sns.heatmap(na_df, cmap="viridis", cbar=False, ax=ax)
    ax.set_title(f"NA Values Matrix ({len(cols_with_na)} columns with missing values)")
    # Rotate column names on the x-axis for readability
    ax.tick_params(axis="x", labelrotation=90)
    fig.tight_layout()

    # Summary of the columns with many missing values
    na_counts = na_df.sum()
    na_percent = na_counts / len(df) * 100

    print("Columns with high NA value percent:")
    for col, count, pct in zip(cols_with_na, na_counts, na_percent):
        if pct > high_na_pct:
            print(f"  - {col}: {count} NAs ({pct:.2f}%)")

    return fig

# Inspect missingness before choosing DROP_COLS_THRESHOLD below
fig = visualize_na_matrix(df=perturbation_filtered)
plt.show()

In [None]:
# Feature columns whose NA fraction is >= this value are dropped by missing_values_filter
DROP_COLS_THRESHOLD = 0.1

In [None]:
from sklearn.covariance import EllipticEnvelope


def intensity_filter(
    cell_data,
    feature_start_idx,
    channel_names=None,
    contamination=0.01,
    random_state=42,
) -> pd.DataFrame:
    """
    Uses EllipticEnvelope to filter outliers by channel intensities.

    Derived from Recursion's EFAAR pipeline: https://github.com/recursionpharma/EFAAR_benchmarking/blob/60df3eb267de3ba13b95f720b2a68c85f6b63d14/efaar_benchmarking/efaar.py#L295

    Intensity columns are the feature columns of ``cell_data`` (position
    ``feature_start_idx`` onward) whose name ends with ``_<channel>_mean`` for one of
    ``channel_names``. An EllipticEnvelope is fit on those columns and the rows it labels
    as outliers are removed. EllipticEnvelope does not accept NaN, so the selected
    columns must be NA-free.

    Args:
        cell_data (pd.DataFrame): Cell data dataframe.
        feature_start_idx (int): Index of the first feature column.
        channel_names (list[str]): Channel names used to select intensity columns.
            Required; the None default only keeps the signature unchanged.
        contamination (float, optional): The proportion of outliers to expect. Defaults to 0.01.
        random_state (int, optional): Seed for EllipticEnvelope. Defaults to 42.

    Returns:
        pd.DataFrame: Inlier rows of ``cell_data`` with a fresh RangeIndex.

    Raises:
        ValueError: If ``channel_names`` is None/empty or no intensity columns match.
    """
    if channel_names is None or len(channel_names) == 0:
        raise ValueError("channel_names must be a non-empty list of channel names")

    # Feature columns of the frame being filtered (not a notebook-level variable)
    feature_cols = cell_data.columns[feature_start_idx:]

    # Determine intensity columns
    channel_suffixes = tuple(f"_{channel}_mean" for channel in channel_names)
    intensity_cols = [col for col in feature_cols if col.endswith(channel_suffixes)]
    if not intensity_cols:
        raise ValueError(
            f"No feature columns end with any of {list(channel_suffixes)}; "
            "check channel_names and feature_start_idx"
        )

    # Fit EllipticEnvelope to intensity cols; fit_predict returns 1 for inliers, -1 for outliers
    mask = EllipticEnvelope(
        contamination=contamination, random_state=random_state
    ).fit_predict(cell_data[intensity_cols])

    # Return filtered cell data
    return cell_data[mask == 1].reset_index(drop=True)

# Channel names come from the phenotype section of the pipeline config
channel_names = config["phenotype"]["channel_names"]
intensity_filtered = intensity_filter(
    perturbation_filtered,
    feature_start_idx=FEATURE_START_IDX,
    channel_names=channel_names,
)
intensity_filtered

In [None]:
from sklearn.impute import KNNImputer

def missing_values_filter(
    cell_data,
    feature_start_idx,
    impute=True,
    drop_rows=False,
    drop_cols=False,
    drop_cols_threshold=None,
):
    """Filter cell data by handling missing values through dropping and/or imputation.

    Only the feature columns (position ``feature_start_idx`` onward) are inspected; the
    leading metadata columns are carried through and only lose the rows removed by
    ``drop_rows``. The steps run in this order:

    1. ``drop_rows``: drop every row with a missing feature value.
    2. Column dropping: if ``drop_cols_threshold`` is given, drop feature columns whose
       NaN proportion is >= the threshold; otherwise, if ``drop_cols`` is True, drop every
       feature column that had a missing value in the input.
    3. ``impute``: fill the remaining missing values with ``KNNImputer(n_neighbors=5)``,
       fit on the still-incomplete columns only.

    Args:
        cell_data (pd.DataFrame): Raw dataframe containing cell measurements.
        feature_start_idx (int): Index of the first feature column.
        impute (bool): Whether to impute remaining missing values after dropping. Defaults to True.
        drop_rows (bool): Whether to drop all rows with any missing values. Defaults to False.
        drop_cols (bool): Whether to drop all columns with any missing values. Ignored when
            ``drop_cols_threshold`` is given. Defaults to False.
        drop_cols_threshold (float, optional): If provided, drops columns with NaN
            proportion >= threshold. Overrides ``drop_cols`` if both are specified.
            Range: 0.0-1.0. Defaults to None.

    Returns:
        pd.DataFrame: Metadata columns followed by the retained feature columns. If the
            input features have no missing values, ``cell_data`` itself is returned.

    Raises:
        ValueError: If ``drop_cols_threshold`` is outside [0, 1], or if imputation is needed
            for a column with no observed values (KNNImputer cannot fill it).
    """
    if drop_cols_threshold is not None and not 0.0 <= drop_cols_threshold <= 1.0:
        raise ValueError(
            f"drop_cols_threshold must be between 0 and 1, got {drop_cols_threshold}"
        )

    # Split metadata and features
    metadata = cell_data.iloc[:, :feature_start_idx].copy()
    features = cell_data.iloc[:, feature_start_idx:].copy()

    # Columns with missing values in the input features
    cols_with_na = features.columns[features.isna().any()].tolist()
    if not cols_with_na:
        return cell_data

    if drop_rows:
        # Positional boolean mask keeps metadata aligned even with duplicate index labels
        complete_rows = features.notna().all(axis=1).to_numpy()
        print(f"Dropped {int((~complete_rows).sum())} rows with missing values")
        features = features[complete_rows]
        metadata = metadata[complete_rows]

    if drop_cols_threshold is not None:
        # Proportion of NaN values in each (remaining) column
        na_proportions = features.isna().mean()
        cols_to_drop = na_proportions.index[na_proportions >= drop_cols_threshold].tolist()

        if cols_to_drop:
            print(f"Dropping {len(cols_to_drop)} columns with ≥{drop_cols_threshold*100}% missing values")
            features = features.drop(columns=cols_to_drop)
    elif drop_cols:
        # Only reached without a threshold: applying both would try to drop columns the
        # threshold step had already removed.
        print(f"Dropping all {len(cols_with_na)} columns with any missing values")
        features = features.drop(columns=cols_with_na)

    if impute:
        remaining_cols_with_na = features.columns[features.isna().any()].tolist()

        if remaining_cols_with_na:
            # KNNImputer discards columns with no observed values, which would make the
            # imputed block narrower than the columns it is written back to.
            fully_missing = features[remaining_cols_with_na].isna().all()
            empty_cols = fully_missing.index[fully_missing].tolist()
            if empty_cols:
                raise ValueError(
                    f"Cannot impute {len(empty_cols)} column(s) with no observed values "
                    f"(e.g. {empty_cols[0]!r}); drop them via drop_cols or drop_cols_threshold"
                )

            print(f"Imputing {len(remaining_cols_with_na)} columns with remaining missing values")

            # Impute only the columns that still have missing values
            imputer = KNNImputer(n_neighbors=5)
            features[remaining_cols_with_na] = imputer.fit_transform(
                features[remaining_cols_with_na]
            )

    # Combine metadata and features
    return pd.concat([metadata, features], axis=1)


missing_values_filtered = missing_values_filter(
    intensity_filtered,
    feature_start_idx=FEATURE_START_IDX,
    drop_cols_threshold=DROP_COLS_THRESHOLD,
)
missing_values_filtered

In [None]:
# Metadata columns whose values are joined with "_" into each cell's batch label
BATCH_COLS = [
    "plate",
    "well",
]

In [None]:
def prepare_alignment_data(cell_data, batch_cols, feature_start_idx):
    """Split cell data into features and metadata, tagging each cell with its batch label.

    The batch label is the string form of the ``batch_cols`` values joined with "_"
    (e.g. plate 1, well "A1" -> "1_A1") and is stored in a new ``batch_values`` column
    of the metadata.

    Args:
        cell_data (pd.DataFrame): Input DataFrame containing metadata and features.
        batch_cols (list[str] | str): Column name(s) used to build the batch label.
        feature_start_idx (int): Index where feature columns start.

    Returns:
        tuple: ``(features, metadata)`` -- features first:
            - features (pd.DataFrame): Columns from ``feature_start_idx`` onward.
            - metadata (pd.DataFrame): Columns before ``feature_start_idx`` plus ``batch_values``.

    Raises:
        ValueError: If ``batch_cols`` is empty.
    """
    # A bare string would otherwise be iterated character by character
    if isinstance(batch_cols, str):
        batch_cols = [batch_cols]
    if len(batch_cols) == 0:
        raise ValueError("batch_cols must contain at least one column name")

    # Create batch values
    batch_values = cell_data[batch_cols[0]].astype(str)
    for col in batch_cols[1:]:
        batch_values = batch_values + "_" + cell_data[col].astype(str)

    # Add batch values to metadata
    metadata = cell_data.iloc[:, :feature_start_idx].copy()
    metadata["batch_values"] = batch_values

    # Extract feature data
    features = cell_data.iloc[:, feature_start_idx:].copy()

    return features, metadata

features, metadata = prepare_alignment_data(
    cell_data=missing_values_filtered,
    batch_cols=BATCH_COLS,
    feature_start_idx=FEATURE_START_IDX,
)

display(metadata, features)

In [None]:
# Adapted from Recursion's EFAAR
# code: https://github.com/recursionpharma/EFAAR_benchmarking/blob/trunk/efaar_benchmarking/efaar.py
# paper: https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012463

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from scipy import linalg

def embed_by_pca(
    features: np.ndarray,
    metadata: pd.DataFrame = None,
    variance_or_ncomp=128,
    batch_col: str | None = None,
    random_state: int | None = 42,
) -> np.ndarray:
    """
    Embed the whole input data using principal component analysis (PCA).
    Note that we explicitly center & scale the data (by batch) before an embedding operation with `PCA`.
    Centering and scaling is done by batch if `batch_col` is not None, and on the whole data otherwise.
    Also note that `PCA` transformer also does mean-centering on the whole data prior to the PCA operation.

    Args:
        features (np.ndarray): Features to transform
        metadata (pd.DataFrame): Metadata. Defaults to None.
        variance_or_ncomp (float, optional): Variance or number of components to keep after PCA.
            Defaults to 128 (n_components). If between 0 and 1, select the number of components such that
            the amount of variance that needs to be explained is greater than the percentage specified.
            If 1, a single component is kept, and if None, all components are kept.
        batch_col (str, optional): Column name for batch information. Defaults to None.
        random_state (int, optional): Seed passed to `PCA`. It matters when sklearn chooses
            the randomized SVD solver (large inputs with fewer requested components), which
            is otherwise not reproducible between runs. Defaults to 42.
    Returns:
        np.ndarray: Transformed data using PCA.
    """
    # centerscale_by_batch works on its own copy, so no extra defensive copy is needed
    scaled = centerscale_by_batch(features, metadata, batch_col)
    pca = PCA(n_components=variance_or_ncomp, random_state=random_state)
    return pca.fit_transform(scaled)

def tvn_on_controls(
    embeddings: np.ndarray,
    metadata: pd.DataFrame,
    pert_col: str,
    control_key: str,
    batch_col: str | None = None,
) -> np.ndarray:
    """
    Apply TVN (Typical Variation Normalization) to the data based on the control perturbation units.

    Steps:
        1. Center and scale all embeddings on the control units.
        2. Project onto the principal components of the control units (PCA fit on controls only).
        3. Center and scale on the control units again, per batch if `batch_col` is given.
        4. If `batch_col` is given, CORAL-align each batch: multiply by the batch's control
           covariance to the power -1/2, then by the all-batch control covariance to the
           power 1/2. Both covariances get 0.5 * I added as regularization.

    Args:
        embeddings (np.ndarray): The embeddings to be normalized; rows align positionally
            with `metadata`.
        metadata (pd.DataFrame): The metadata containing information about the samples.
        pert_col (str): The column name in the metadata DataFrame that represents the perturbation labels.
        control_key (str): The control perturbation label.
        batch_col (str, optional): Column name in the metadata DataFrame representing the batch labels
            to be used for CORAL normalization. Defaults to None.

    Returns:
        np.ndarray: The normalized embeddings (a new array; the input is not modified).

    Raises:
        ValueError: If `batch_col` is given and a batch has fewer than two control units,
            which leaves its control covariance undefined.
    """
    ctrl_ind = (metadata[pert_col] == control_key).to_numpy()

    # centerscale_on_controls returns a new array, so the caller's input is untouched
    embeddings = centerscale_on_controls(embeddings, metadata, pert_col, control_key)
    embeddings = PCA().fit(embeddings[ctrl_ind]).transform(embeddings)
    embeddings = centerscale_on_controls(
        embeddings, metadata, pert_col, control_key, batch_col
    )
    if batch_col is None:
        return embeddings

    regularizer = 0.5 * np.eye(embeddings.shape[1])
    target_cov = np.cov(embeddings[ctrl_ind], rowvar=False, ddof=1) + regularizer
    # Same for every batch, so computed once. fractional_matrix_power may return a
    # complex-typed array; for these symmetric positive-definite matrices the exact
    # result is real, and only the real part is used.
    target_sqrt = linalg.fractional_matrix_power(target_cov, 0.5).real

    batch_labels = metadata[batch_col]
    for batch in batch_labels.unique():
        batch_ind = (batch_labels == batch).to_numpy()
        batch_control_ind = batch_ind & ctrl_ind
        n_batch_controls = int(batch_control_ind.sum())
        if n_batch_controls < 2:
            raise ValueError(
                f"Batch {batch!r} has {n_batch_controls} control unit(s); at least 2 are "
                "needed to estimate its control covariance"
            )
        source_cov = np.cov(embeddings[batch_control_ind], rowvar=False, ddof=1) + regularizer
        source_inv_sqrt = linalg.fractional_matrix_power(source_cov, -0.5).real
        # Whiten with the batch's control covariance, then recolor with the target's
        embeddings[batch_ind] = embeddings[batch_ind] @ source_inv_sqrt @ target_sqrt
    return embeddings

def centerscale_by_batch(
    features: np.ndarray, metadata: pd.DataFrame = None, batch_col: str | None = None
) -> np.ndarray:
    """
    Center and scale the input features by each batch. Not using any controls at all.
    We are using this prior to embedding high-dimensional data with PCA.

    Args:
        features (np.ndarray): Input features to be centered and scaled (an array-like
            such as a DataFrame is also accepted).
        metadata (pd.DataFrame): Metadata information for the input features; rows align
            positionally with `features`. Required when `batch_col` is given.
        batch_col (str): Name of the column in metadata that contains batch information.
            If None, the whole data is standardized at once.

    Returns:
        np.ndarray: Centered and scaled features as a new float array (the input is not modified).

    Raises:
        ValueError: If `batch_col` is given but `metadata` is None.
    """
    if batch_col is None:
        # StandardScaler returns a new float array and leaves its input untouched
        return StandardScaler().fit_transform(features)
    if metadata is None:
        raise ValueError("metadata must be provided if batch_col is not None")

    scaled = np.array(features, copy=True)
    # Scaled values are fractional; writing them into an integer array would truncate them
    if not np.issubdtype(scaled.dtype, np.floating):
        scaled = scaled.astype(np.float64)

    batch_labels = metadata[batch_col]
    for batch in batch_labels.unique():
        ind = (batch_labels == batch).to_numpy()
        scaled[ind, :] = StandardScaler().fit_transform(scaled[ind, :])
    return scaled

def centerscale_on_controls(
    embeddings: np.ndarray,
    metadata: pd.DataFrame,
    pert_col: str,
    control_key: str,
    batch_col: str | None = None,
) -> np.ndarray:
    """
    Standardize embeddings using statistics of the control units only.

    A StandardScaler is fit on the rows whose `pert_col` equals `control_key` and then
    applied to all rows. With `batch_col`, a separate scaler is fit on each batch's
    controls and applied to that batch's rows.

    Args:
        embeddings (numpy.ndarray): The embeddings to be aligned; rows align positionally with `metadata`.
        metadata (pandas.DataFrame): The metadata containing information about the embeddings.
        pert_col (str): The column in the metadata containing perturbation information.
        control_key (str): The value of `pert_col` that marks control units.
        batch_col (str, optional): Column name in the metadata representing the batch labels.
            Defaults to None.
    Returns:
        numpy.ndarray: The aligned embeddings (a new array; the input is not modified).
    """
    is_control = (metadata[pert_col] == control_key).to_numpy()

    if batch_col is None:
        return StandardScaler().fit(embeddings[is_control]).transform(embeddings)

    aligned = embeddings.copy()
    batch_labels = metadata[batch_col]
    for label in batch_labels.unique():
        in_batch = (batch_labels == label).to_numpy()
        # Batches are disjoint, so this batch's rows are still untransformed when fitting
        scaler = StandardScaler().fit(aligned[in_batch & is_control])
        aligned[in_batch] = scaler.transform(aligned[in_batch])
    return aligned

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

def pca_variance_plot(features, variance_threshold=0.95, random_state=42):
    """Fit a full PCA on standardized features and plot the cumulative explained variance.

    Args:
        features (pd.DataFrame | np.ndarray): Cell-by-feature matrix. It is standardized on
            the whole data (no batches) before PCA.
        variance_threshold (float): Cumulative explained-variance fraction to reach.
            Defaults to 0.95.
        random_state (int): Random seed passed to ``PCA``. Defaults to 42.

    Returns:
        tuple: A tuple containing:
            - pca_df_threshold (pd.DataFrame): PCA scores of the first ``n_components``
              components, in columns ``pca_0 ... pca_{n_components - 1}`` (RangeIndex).
            - n_components (int): Smallest number of components whose cumulative explained
              variance is >= ``variance_threshold`` (all components if it is never reached).
            - pca_object (PCA): Fitted PCA object.
            - fig (matplotlib.figure.Figure): Figure object for the explained variance plot.
    """
    # centerscale_by_batch returns a new scaled array; the caller's data is untouched
    scaled = centerscale_by_batch(features)

    # Initialize and fit PCA
    pca = PCA(random_state=random_state)
    pca_transformed = pca.fit_transform(scaled)

    # Find number of components needed for threshold
    cumsum = pca.explained_variance_ratio_.cumsum()
    reached = np.flatnonzero(cumsum >= variance_threshold)
    # Round-off can leave the final cumulative value just under 1.0, so a threshold near
    # 1.0 might never be reached; keep every component instead of raising IndexError.
    n_components = int(reached[0]) + 1 if reached.size else len(cumsum)

    # Create variance plot; x is the number of components included (1-based), so the
    # vertical line at n_components marks the point where the threshold is crossed.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(np.arange(1, len(cumsum) + 1), cumsum, "-")
    ax.axhline(
        variance_threshold,
        linestyle="--",
        color="red",
        label=f"{variance_threshold * 100}% Threshold",
    )
    ax.axvline(n_components, linestyle="--", color="blue", label=f"n={n_components}")
    ax.set_ylabel("Cumulative fraction of variance explained")
    ax.set_xlabel("Number of principal components included")
    ax.set_title("PCA Explained Variance Ratio")
    ax.grid(True)
    ax.legend()

    print(
        f"Number of components needed for {variance_threshold * 100}% variance: {n_components}"
    )
    print(f"Shape of input data: {scaled.shape}")

    # Keep only the components needed to reach the threshold
    pca_df_threshold = pd.DataFrame(
        pca_transformed[:, :n_components],
        columns=[f"pca_{i}" for i in range(n_components)],
    )

    print(f"Shape of PCA transformed and reduced data: {pca_df_threshold.shape}")

    return pca_df_threshold, n_components, pca, fig

# Explained-variance curve used to choose PC_COUNT below
pca_df_threshold, n_components, pca, fig = pca_variance_plot(
    features,
    variance_threshold=0.95,
)

In [None]:
# Number of PCA components kept by embed_by_pca below.
# NOTE(review): presumably read off the variance plot above; confirm it matches n_components.
PC_COUNT = 432

In [None]:
# Cells per perturbation label; check that the control label chosen below is present
metadata.loc[:, PERTURBATION_NAME_COL].value_counts()

In [None]:
# Value of PERTURBATION_NAME_COL marking control cells; TVN statistics are fit on these cells
CONTROL_KEY = "nontargeting"

In [None]:
# Batch-wise standardize + PCA, then TVN/CORAL alignment on the control cells
pca_embeddings = embed_by_pca(
    features.values,
    metadata,
    variance_or_ncomp=PC_COUNT,
    batch_col="batch_values",
)
tvn_normalized = tvn_on_controls(
    pca_embeddings,
    metadata,
    pert_col=PERTURBATION_NAME_COL,
    control_key=CONTROL_KEY,
    batch_col="batch_values",
)

# Re-attach metadata to the aligned embeddings
tvn_normalized_columns = ["PCA_" + str(i) for i in range(tvn_normalized.shape[1])]
tvn_normalized_df = pd.DataFrame(
    tvn_normalized, index=metadata.index, columns=tvn_normalized_columns
)
aligned_cell_data = pd.concat([metadata, tvn_normalized_df], axis=1)

aligned_cell_data