In [1]:
import zarr
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
import shutil
import pathlib
import zipfile

In [2]:
def create_binary_masks_with_clipped_limits(original_images, contrast_limits_per_channel):
    """
    Create binary masks from original images based on specified contrast limits for each channel.

    Parameters:
    - original_images (zarr array): The input array of original images with shape (n_channels, height, width).
    - contrast_limits_per_channel (dict): A dictionary where keys are channel names and values are the minimum 
                                          contrast limit values for the corresponding channel.

    Returns:
    - binary_masks_array (zarr array): An array of binary masks with the same shape as original_images, where each pixel 
                                       is set to 1 if its value in original_images exceeds the corresponding channel's 
                                       minimum contrast limit, otherwise set to 0.
    """
    binary_masks_array = zarr.zeros_like(original_images)  
    
    # Loop over the channels
    for i in range(original_images.shape[0]):  
        channel_name = list(contrast_limits_per_channel.keys())[i]
        min_val = contrast_limits_per_channel[channel_name]

        print(min_val)
        # Create binary mask using the minimum value as threshold
        binary_masks_array[i] = np.where(original_images[i] > min_val, 1, 0).astype(np.uint16)
    
    return binary_masks_array



def remove_sparse_images(binary_patches, threshold=0.8):
    """
    Remove images from the binary_patches array where the percentage of zeros exceeds a given threshold.
    
    Parameters:
    - binary_patches (numpy array): The input array of binary images with shape (n_batches, n_images, height, width).
    - threshold (float): The threshold for the percentage of zeros in an image to determine if it should be removed.

    Returns:
    - binary_patches_subset (numpy array): The subset of binary_patches array after removing images exceeding the threshold.
    - indices_to_remove (numpy array): The indices of images removed.
    """

    # Calculate the total number of pixels in each image
    total_pixels = binary_patches.shape[2] * binary_patches.shape[3]

    # Calculate the number of zeros in each image
    num_zeros_per_image = np.sum(binary_patches[0] == 0, axis=(1, 2))

    # Calculate the percentage of zeros in each image
    percentage_zeros_per_image = num_zeros_per_image / total_pixels

    # Create a boolean mask indicating images where the percentage of zeros exceeds the threshold
    mask = percentage_zeros_per_image > threshold

    # Find the indices of images where the percentage of zeros exceeds the threshold
    indices_to_remove = np.where(mask)[0]

    # Remove these indices from the array
    binary_patches_subset = np.delete(binary_patches, indices_to_remove, axis=1)

    return binary_patches_subset, indices_to_remove

# Function to drop indices from a Zarr array
def drop_indices_from_zarr(zarr_array, indices_to_drop):
    remaining_indices = np.setdiff1d(np.arange(zarr_array.shape[1]), indices_to_drop)
    return zarr_array[:, remaining_indices, :, :]
    
def add_cluster_column_for_split(df, main_df, clustering_to_append):
    """

    Parameters:
    - df (DataFrame): The split dataframe which will receive the clustering column.
    - main_df (DataFrame): The main dataframe containing all the clustering columns.
    - clustering_to_append (str): The clustering column to append to the split dataframe.

    Returns:
    - DataFrame: The updated split dataframe with the clustering column added.
    """
    # Merge the split dataframe with the main dataframe based on CellID
    df = df.merge(main_df[['CellID', clustering_to_append]], on='CellID', how='left')
    return df


def filter_clusters(df, clustering_column, clusters_to_keep):
    """
    Filter the dataframe to keep only specified clusters and save the dropped indices and CellIDs.

    Parameters:
    - df (DataFrame): The dataframe to be filtered.
    - clustering_column (str): The clustering column to filter on.
    - clusters_to_keep (list): List of clusters to keep.

    Returns:
    - DataFrame: The filtered dataframe.
    - list: The list of indices that were dropped.
    - list: The list of CellIDs of the dropped indices.
    """
    # Identify rows to drop
    mask_to_drop = ~df[clustering_column].isin(clusters_to_keep)
    indices_dropped = df.index[mask_to_drop].tolist()
    cellids_dropped = df.loc[mask_to_drop, 'CellID'].tolist()

    # Drop the rows
    df_filtered = df[~mask_to_drop].reset_index(drop=True)

    return df_filtered, indices_dropped, cellids_dropped

def drop_indices_from_zarr(zarr_array, indices_to_drop):
    """
    Drop the specified indices from the Zarr array.

    Parameters:
    - zarr_array (zarr array): The input Zarr array.
    - indices_to_drop (list): List of indices to drop.

    Returns:
    - zarr array: The updated Zarr array with specified indices dropped.
    """
    remaining_indices = np.setdiff1d(np.arange(zarr_array.shape[1]), indices_to_drop)
    return zarr_array[:, remaining_indices, :, :], remaining_indices

def zip_zarr_directory(directory_path, zip_file_path):
    """
    Zip the contents of the given Zarr directory into a ZIP file.

    Parameters:
    - directory_path (str): The path to the Zarr directory.
    - zip_file_path (str): The path to the output ZIP file.
    """
    dir_path = pathlib.Path(directory_path)
    with zipfile.ZipFile(zip_file_path, "w", compression=zipfile.ZIP_STORED) as zf:
        for f in dir_path.rglob("*"):
            zf.write(f, f.relative_to(dir_path))

In [3]:
# specify VAE clustering of interest and associated window size and latent dimension
vae_output_dir = '/Users/hitsloaner/Downloads/VAE9_VIG7'   #SPECIFY PATH OF YOUR VAE WINDOW
clustering = 'VAE9_VIG7'
window_size = 14  # in pixels
latent_dim = 850

# Define the contrast limits for each channel
contrast_limits = {
    'Keratin_570': 9500,  # Channel 0 in z_subset
    'Ecad_488': 5000,  # Channel 1 in z_subset
    'PCNA_488': 8000.0  # Channel 2 in z_subset
}

# load the original zarr patches for that window size
directory = f'{vae_output_dir}/2_cellcutter_output_win{window_size}/'

main_csv = '/Users/hitsloaner/Downloads/main_all_clustering.csv'  # SPECIFY PATH OF YOUR MAIN CLUSTERING FILE

main = pd.read_csv(main_csv)

tumor_clusters_for_analysis = [0, 1, 3, 5, 6, 21]

In [4]:
splits = ['train', 'test', 'validate']

# Sanity check: Compare kept indices between the original CSV files and Zarr files
for split in splits:
    csv_of_split_path = f'{vae_output_dir}/1_cellcutter_input/{split}.csv'
    csv_file = pd.read_csv(csv_of_split_path)

    # Add the clustering column to the split CSV file
    csv_file_with_clustering_column = add_cluster_column_for_split(csv_file, main, clustering)

    # Initial indices of the CSV file before any filtering
    initial_indices_csv = set(csv_file_with_clustering_column.index)

    # Load the regular file
    regular_file = f"{split}_thumbnails_{window_size}.zip"
    regular_path = os.path.join(directory, regular_file)

    if os.path.exists(regular_path):
        z_store = zarr.ZipStore(regular_path, mode='r')
        z = zarr.open(store=z_store)
        
        initial_indices_zarr = set(np.arange(z.shape[1]))

        # Check if the initial indices are the same
        assert initial_indices_csv == initial_indices_zarr, f"Initial indices mismatch in {split} split"

print("Sanity check passed: Initial indices match between CSV files and Zarr files.")

Sanity check passed: Initial indices match between CSV files and Zarr files.


In [5]:
z.shape

(3, 16267, 14, 14)

In [5]:
# Lists to track all removed indices
all_removed_indices_from_csv = []
all_removed_indices_from_zarr = []
all_kept_indices = []

# Loop over each split and process
for split in splits:

    # Preprocess CSV files
    csv_of_split_path = f'{vae_output_dir}/1_cellcutter_input/{split}.csv'
    csv_file = pd.read_csv(csv_of_split_path)

    # Create a backup of the original CSV file
    csv_backup_path = f'{vae_output_dir}/1_cellcutter_input/{split}_backup.csv'
    shutil.copy(csv_of_split_path, csv_backup_path)

    # Add the clustering column to the split CSV file
    csv_file_with_clustering_column = add_cluster_column_for_split(csv_file, main, clustering)
    
    # Filter the dataframe to keep only specified clusters and save dropped indices and CellIDs
    filtered_csv_file, indices_dropped_from_csv, cellids_dropped = filter_clusters(csv_file_with_clustering_column, clustering, tumor_clusters_for_analysis)
    
    # Print the number of dropped indices and CellIDs
    print(f"{split} - Indices dropped after clustering filter:", len(indices_dropped_from_csv))
    print(f"{split} - CellIDs dropped after clustering filter:", len(cellids_dropped))

    # Add to the list of all removed indices from the CSV
    all_removed_indices_from_csv.extend(indices_dropped_from_csv)

    # DO ZARR PROCESSING
    regular_file = f"{split}_thumbnails_{window_size}.zip"
    seg_file = f"{split}_thumbnails_{window_size}_seg.zip"
    
    # Construct the full paths
    regular_path = os.path.join(directory, regular_file)
    seg_path = os.path.join(directory, seg_file)
    
    # Load the regular file
    if os.path.exists(regular_path):
        z_store = zarr.ZipStore(regular_path, mode='r')
        z = zarr.open(store=z_store)
        
        # Drop the identified indices from the Zarr array
        z_tumor_only, remaining_indices_after_filter = drop_indices_from_zarr(z, indices_dropped_from_csv)

        # Subset channels of interest (i.e. we are interested in tumor relevant markers Keratin, ECAD, and PCNA)
        channels = [2, 3, 19]
        z_subset = z_tumor_only[channels, :, :, :]

        # Binarize and filter the Zarr patches to remove sparse areas
        z_binary = create_binary_masks_with_clipped_limits(z_subset, contrast_limits) 
        z_filtered, indices_to_remove_from_zarr = remove_sparse_images(z_binary, threshold=0.8)

        # Add to the list of all removed indices from the Zarr array
        all_removed_indices_from_zarr.extend(indices_to_remove_from_zarr)

        # Further drop these indices from the filtered CSV file
        filtered_csv_file = filtered_csv_file.drop(indices_to_remove_from_zarr).reset_index(drop=True)

        # Print the number of additional dropped indices
        print(f"{split} - Additional indices dropped after binarization:", len(indices_to_remove_from_zarr))

        # Save the filtered regular Zarr array
        zarr_filtered_path = os.path.join(directory, f'{split}_thumbnails_{window_size}')
        zarr_filtered_store = zarr.DirectoryStore(zarr_filtered_path)
        zarr_filtered = zarr.open(store=zarr_filtered_store, mode='w', shape=z_filtered.shape, dtype=z_filtered.dtype)
        zarr_filtered[:] = z_filtered

        # Zip the filtered regular Zarr array
        zip_zarr_directory(zarr_filtered_path, f'{directory}/{split}_thumbnails_{window_size}.zip')
        shutil.rmtree(zarr_filtered_path)  # Remove the unzipped directory after zipping

        # Drop the same indices from the _seg Zarr array
        if os.path.exists(seg_path):
            seg_store = zarr.ZipStore(seg_path, mode='r')
            seg = zarr.open(store=seg_store)
            
            # Drop indices from seg array based on initial tumor filtering
            seg_tumor_only, _ = drop_indices_from_zarr(seg, indices_dropped_from_csv)
            
            # Drop indices from seg array based on further filtering
            seg_filtered, _ = drop_indices_from_zarr(seg_tumor_only, indices_to_remove_from_zarr)

            # Save the filtered seg Zarr array
            seg_filtered_path = os.path.join(directory, f'{split}_thumbnails_{window_size}_seg')
            seg_filtered_store = zarr.DirectoryStore(seg_filtered_path)
            zarr_seg_filtered = zarr.open(store=seg_filtered_store, mode='w', shape=seg_filtered.shape, dtype=seg_filtered.dtype)
            zarr_seg_filtered[:] = seg_filtered

            # Zip the filtered seg Zarr array
            zip_zarr_directory(seg_filtered_path, f'{directory}/{split}_thumbnails_{window_size}_seg.zip')
            shutil.rmtree(seg_filtered_path)  # Remove the unzipped directory after zipping

        # Track the kept indices
        final_kept_indices = np.setdiff1d(remaining_indices_after_filter, indices_to_remove_from_zarr)
        all_kept_indices.extend(final_kept_indices)

    print("Checking if segmentation shape is same as image shape:", seg_filtered.shape[1] == z_filtered.shape[1])

    # Save the final filtered CSV file (replacing the original file)
    filtered_csv_file.to_csv(csv_of_split_path, index=False)

# Convert the lists of all removed and kept indices to numpy arrays and save
all_removed_indices_from_csv = np.array(all_removed_indices_from_csv)
all_removed_indices_from_zarr = np.array(all_removed_indices_from_zarr)
all_kept_indices = np.array(all_kept_indices)

np.save(f'{vae_output_dir}/all_removed_indices_from_csv.npy', all_removed_indices_from_csv)
np.save(f'{vae_output_dir}/all_removed_indices_from_zarr.npy', all_removed_indices_from_zarr)
np.save(f'{vae_output_dir}/all_kept_indices.npy', all_kept_indices)

print("Processing completed. Indices saved.")

train - Indices dropped after clustering filter: 238718
train - CellIDs dropped after clustering filter: 238718
9500
5000
8000.0
train - Additional indices dropped after binarization: 14403
Checking if segmentation shape is same as image shape: True
test - Indices dropped after clustering filter: 29674
test - CellIDs dropped after clustering filter: 29674
9500
5000
8000.0
test - Additional indices dropped after binarization: 1776
Checking if segmentation shape is same as image shape: True
validate - Indices dropped after clustering filter: 29769
validate - CellIDs dropped after clustering filter: 29769
9500
5000
8000.0
validate - Additional indices dropped after binarization: 1874
Checking if segmentation shape is same as image shape: True
Processing completed. Indices saved.
