In [1]:
import os
import shutil
import pathlib
import zipfile

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import zarr

# project root: the directory one level above the current working directory
parent_dir = os.path.abspath(os.pardir)

In [2]:
# Note: the outputs of this notebook were subsequently processed through the VAE pipeline on o2.

In [3]:
def binarize_patches(patches, contrast_limits):
    """
    Binarize selected channels of an image-patch array using per-channel thresholds.

    Parameters:
    - patches (numpy or zarr array): Image patches with shape
      (n_channels, n_patches, height, width).
    - contrast_limits (dict): Maps a channel name to a ``(channel_index, min_value)`` tuple.
      ``channel_index`` indexes the first axis of ``patches``. ``min_value`` is the
      binarization threshold for that channel.

    Returns:
    - binary_patches (zarr array): Array created with ``zarr.zeros_like`` from the selected
      channels. It has shape (len(contrast_limits), n_patches, height, width) and the dtype
      of ``patches``. Output channel ``i`` corresponds to the ``i``-th entry of
      ``contrast_limits`` (dict insertion order), not to the original channel index.
      A pixel is 1 where its value is strictly greater than the channel's threshold,
      otherwise 0.
    """
    channels = [channel for channel, _ in contrast_limits.values()]
    limits = [limit for _, limit in contrast_limits.values()]

    # Selecting a list of channels from a zarr array needs orthogonal indexing
    # (same approach as drop_indices_from_zarr). NumPy arrays support it directly.
    if hasattr(patches, 'oindex'):
        selected = patches.oindex[channels, :, :, :]
    else:
        selected = patches[channels, :, :, :]

    binary_patches = zarr.zeros_like(selected)

    # One threshold per selected channel; output channel e <-> e-th contrast_limits entry
    for e, limit in enumerate(limits):
        binary_patches[e] = (selected[e] > limit).astype(np.uint16)

    return binary_patches

In [4]:
def filter_low_quality_patches(binary_patches, threshold=0.8, channel=0):
    """
    Remove patches whose fraction of zero-valued pixels exceeds ``threshold``.

    By default only channel 0 is inspected, which is the historical behavior. Pass
    ``channel=None`` to remove a patch when *any* channel exceeds the threshold.

    Parameters:
    - binary_patches (numpy or zarr array): Binary images with shape
      (n_channels, n_patches, height, width). A zarr array is read fully into memory.
    - threshold (float): Fraction (0-1) of zero pixels above which a patch is removed.
      A patch exactly at the threshold is kept.
    - channel (int or None): Channel used to judge patch quality, or None to flag a
      patch if any channel exceeds the threshold.

    Returns:
    - binary_patches_subset (numpy array): ``binary_patches`` with the flagged patches
      removed along axis 1, across all channels.
    - indices_to_remove (numpy array): Positions along axis 1 of the removed patches.
    """
    # Materialize once: a view for numpy, a full read for zarr. Reused for the delete below.
    patches = binary_patches[:]

    # number of pixels in each patch (height * width)
    total_pixels = patches.shape[2] * patches.shape[3]

    # fraction of zero pixels for every (channel, patch) pair
    zero_fraction = np.sum(patches == 0, axis=(2, 3)) / total_pixels

    # True for patches to drop
    if channel is None:
        mask = np.any(zero_fraction > threshold, axis=0)
    else:
        mask = zero_fraction[channel] > threshold

    indices_to_remove = np.flatnonzero(mask)

    # remove flagged patches from every channel
    binary_patches_subset = np.delete(patches, indices_to_remove, axis=1)

    return binary_patches_subset, indices_to_remove

In [5]:
def add_cluster_labels(subset, main, clustering):
    """
    Attach a clustering column from ``main`` to ``subset`` by matching on ``CellID``.

    This is a left merge, so every row of ``subset`` is kept in its original order.
    Rows whose CellID is absent from ``main`` get NaN in the new column. The merge is
    validated as many-to-one. A CellID that appears more than once in ``main`` therefore
    raises, instead of silently duplicating rows. Row positions must stay aligned with
    patch positions in the corresponding zarr arrays.

    Parameters:
    - subset (DataFrame): The split dataframe which will receive the clustering column.
    - main (DataFrame): The main dataframe containing all the clustering columns.
    - clustering (str): The clustering column to append to the split dataframe.

    Returns:
    - DataFrame: The updated split dataframe with the clustering column added.

    Raises:
    - pandas.errors.MergeError: if ``main['CellID']`` contains duplicates.
    """
    # merge subset dataframe with main (i.e. full) dataframe based on CellID
    return subset.merge(
        main[['CellID', clustering]], on='CellID', how='left', validate='many_to_one'
    )

In [6]:
def filter_clusters(subset, clustering, clusters_to_keep):
    """
    Keep only rows whose cluster label is in ``clusters_to_keep``.

    Parameters:
    - subset (DataFrame): The dataframe to be filtered.
    - clustering (str): The clustering column to filter on.
    - clusters_to_keep (list): List of clusters to keep.

    Returns:
    - DataFrame: The rows of ``subset`` with a kept cluster label, with the index reset.
      Rows with a missing (NaN) label are dropped.
    - list: Zero-based *positions* (not index labels) of the dropped rows in ``subset``.
      These can be used to drop the matching entries from an array whose patch axis
      follows the row order of ``subset``. They are the same as index labels when
      ``subset`` has a default RangeIndex.
    """
    # True for rows to keep
    keep_mask = subset[clustering].isin(clusters_to_keep).to_numpy()

    # positional (not label-based) indices, so they stay valid for zarr/numpy indexing
    positions_to_drop = np.flatnonzero(~keep_mask).tolist()

    # filter data subset
    filtered_subset = subset[keep_mask].reset_index(drop=True)

    return filtered_subset, positions_to_drop

In [7]:
def drop_indices_from_zarr(z, indices_to_drop):
    """
    Drop entries along axis 1 (the patch axis) of a zarr or numpy array.

    Parameters:
    - z (zarr or numpy array): Array with at least two dimensions and patches along
      axis 1, e.g. (n_channels, n_patches, height, width).
    - indices_to_drop (array-like of int): Positions along axis 1 to drop. Duplicates
      and positions outside [0, z.shape[1]) are ignored.

    Returns:
    - The remaining entries in their original order. For zarr input this is the result
      of ``z.oindex[...]``, which zarr returns as an in-memory NumPy array. For NumPy
      input it is a NumPy array.
    """
    remaining_indices = np.setdiff1d(np.arange(z.shape[1]), indices_to_drop)

    # select remaining positions on axis 1 and keep every other axis whole (any ndim >= 2)
    selection = (slice(None), remaining_indices) + (slice(None),) * (z.ndim - 2)

    # An integer-array selection on a zarr array needs orthogonal indexing. zarr arrays are
    # detected by their `oindex` attribute rather than the internal `zarr.core.Array` path.
    if hasattr(z, 'oindex'):
        return z.oindex[selection]
    return z[selection]

In [8]:
def zip_zarr_directory(directory_path, zip_file_path):
    """
    Zip the contents of the given Zarr directory into an uncompressed ZIP file.

    Archive member names are file paths relative to ``directory_path``, using POSIX
    separators, so the store's keys sit at the archive root. Only regular files are
    written; directory entries are omitted. Files are added in sorted order, so the
    archive layout is reproducible across runs and filesystems.

    Parameters:
    - directory_path (str): The path to the Zarr directory.
    - zip_file_path (str): The path to the output ZIP file (overwritten if it exists).
    """
    root = pathlib.Path(directory_path)
    member_paths = sorted(path for path in root.rglob("*") if path.is_file())

    # ZIP_STORED: no extra compression (zarr chunks are presumably already compressed)
    with zipfile.ZipFile(zip_file_path, "w", compression=zipfile.ZIP_STORED) as archive:
        for path in member_paths:
            archive.write(path, path.relative_to(root).as_posix())

In [9]:
# specify VAE clustering of interest and corresponding window size
tumor_clusters = [0, 1, 3, 5, 6, 21, 22]  # labels in the `clustering` column kept as tumor
clustering = 'VAE9_VIG7'  # column header in main.csv
window_size = 14  # in pixels; used to build the input/output zarr filenames

# channel contrast limits: {channel name: (channel index in the input zarr, threshold)}
# channels in z_binary are ordered by dict entry order, not by input channel index
contrast_limits = {
    'Keratin_570': (2, 9500),  # input channel 2 -> channel 0 in z_binary
    # 'Ecad_488': (3, 5000),  # channel 1 in z_binary
    # 'PCNA_488': (19, 8000)  # channel 2 in z_binary
}

# load dataframe containing cluster labels
main_csv = os.path.join(parent_dir, 'input/main.csv')
main = pd.read_csv(main_csv)

# out dir (exist_ok avoids the check-then-create race and makes the cell re-runnable)
out = os.path.join(parent_dir, 'output/binary/binarized_patches')
os.makedirs(out, exist_ok=True)

In [10]:
data_subsets = ['train', 'test', 'validate']

# Row i of each subset CSV must stay aligned with patch i (axis 1) of the matching image
# and segmentation zarrs, so every row drop below is mirrored in both arrays.

# CellIDs of artifact patches to exclude; the file does not depend on the subset, so read it once
artifact_cellids = pd.read_csv(os.path.join(os.getcwd(), 'artifact_cellids.csv'))


def _save_zipped_zarr(array, dir_path):
    """Write `array` to a zarr DirectoryStore at `dir_path`, zip it to `<dir_path>.zip`, then delete the directory."""
    store = zarr.DirectoryStore(dir_path)
    zarr_out = zarr.open(store=store, mode='w', shape=array.shape, dtype=array.dtype)
    zarr_out[:] = array
    zip_zarr_directory(dir_path, f'{dir_path}.zip')
    shutil.rmtree(dir_path)  # remove unzipped directory after zipping


# loop over each data subset and process
for subset in data_subsets:

    print(f'Working on {subset} dataset...')

    # read CSV file
    data = pd.read_csv(os.path.join(parent_dir, f'input/VAE9_VIG7/1_cellcutter_input/{subset}.csv'))

    # add clustering column to data subset; a row-count change would break CSV/zarr alignment
    data_clustered = add_cluster_labels(subset=data, main=main, clustering=clustering)
    assert len(data_clustered) == len(data), f'{subset}: duplicate CellIDs in main.csv'

    # filter subset to keep only tumor clusters, save dropped indices
    data_filtered, dropped_indices = filter_clusters(
        subset=data_clustered, clustering=clustering, clusters_to_keep=tumor_clusters
    )
    print(f'{len(dropped_indices)} indices dropped after cluster filtering {subset} dataset')

    input_dir = os.path.join(parent_dir, 'input/VAE9_VIG7/2_cellcutter_output_win14')

    # drop indices from the corresponding zarr file; the store is closed once the
    # selected patches have been read
    zarr_path = os.path.join(input_dir, f'{subset}_thumbnails_{window_size}.zip')
    with zarr.ZipStore(zarr_path, mode='r') as z_store:
        z = zarr.open(store=z_store)
        assert z.shape[1] == len(data), f'{subset}: CSV rows and image patches differ'
        z_tumor = drop_indices_from_zarr(z=z, indices_to_drop=dropped_indices)

    # binarize zarr patches
    z_binary = binarize_patches(patches=z_tumor, contrast_limits=contrast_limits)

    # drop low quality binarized patches from zarr and CSV file
    z_filtered, low_quality_indices = filter_low_quality_patches(binary_patches=z_binary, threshold=0.8)
    data_filtered = data_filtered.drop(low_quality_indices).reset_index(drop=True)
    print(f'{len(low_quality_indices)} additional indices dropped from {subset} dataset after binarization')

    # remove patch outliers (data_filtered has a RangeIndex here, so labels are positions)
    artifact_indices = data_filtered.index[data_filtered['CellID'].isin(artifact_cellids['CellID'])]
    data_filtered = data_filtered.drop(artifact_indices).reset_index(drop=True)
    z_filtered = drop_indices_from_zarr(z=z_filtered, indices_to_drop=artifact_indices)
    print(f'{len(artifact_indices)} additional indices dropped from {subset} dataset after removing outliers')

    # apply the same three drops to the segmentation zarr
    seg_zarr_path = os.path.join(input_dir, f'{subset}_thumbnails_{window_size}_seg.zip')
    with zarr.ZipStore(seg_zarr_path, mode='r') as seg_store:
        z_seg = zarr.open(store=seg_store)
        assert z_seg.shape[1] == len(data), f'{subset}: CSV rows and seg patches differ'
        seg_tumor_only = drop_indices_from_zarr(z=z_seg, indices_to_drop=dropped_indices)
    seg_filtered = drop_indices_from_zarr(z=seg_tumor_only, indices_to_drop=low_quality_indices)
    seg_filtered = drop_indices_from_zarr(z=seg_filtered, indices_to_drop=artifact_indices)

    # verify alignment before writing anything, so misaligned outputs are never saved
    lengths_match = len(data_filtered) == seg_filtered.shape[1] == z_filtered.shape[1]
    print('Checking if csv, zarr, and seg_zarr are same length:', lengths_match)
    assert lengths_match, f'{subset}: CSV rows and zarr patch counts diverged'

    # save filtered data subset file
    data_filtered.to_csv(os.path.join(out, f'{subset}.csv'), index=False)

    # save filtered image and seg zarr arrays as zipped stores
    _save_zipped_zarr(z_filtered, os.path.join(out, f'{subset}_thumbnails_{window_size}'))
    _save_zipped_zarr(seg_filtered, os.path.join(out, f'{subset}_thumbnails_{window_size}_seg'))

    print()

print('Image patch filteration and binarization complete.')

Working on train dataset...
234670 indices dropped after cluster filtering train dataset
14913 additional indices dropped from train dataset after binarization
5661 additional indices dropped from train dataset after removing outliers
Checking if csv, zarr, and seg_zarr are same length: True

Working on test dataset...
29151 indices dropped after cluster filtering test dataset
1833 additional indices dropped from test dataset after binarization
735 additional indices dropped from test dataset after removing outliers
Checking if csv, zarr, and seg_zarr are same length: True

Working on validate dataset...
29298 indices dropped after cluster filtering validate dataset
1931 additional indices dropped from validate dataset after binarization
701 additional indices dropped from validate dataset after removing outliers
Checking if csv, zarr, and seg_zarr are same length: True

Image patch filteration and binarization complete.
