# Create patch dataset from real data

**Author**: Prisca Dotti  
**Last modified**: 14.06.24

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [2]:
import os
import numpy as np
import napari
from config import config
from utils.in_out_tools import load_annotations_ids, load_movies_ids
from data.data_processing_tools import detect_spark_peaks

In [3]:
# select samples to include in dataset
# Sample IDs are zero-padded strings. Uncomment an entry to include that
# sample; only the uncommented IDs are processed by the cells below.
# NOTE: IDs "26", "31" and "37" do not appear in this list at all.
sample_ids = [
    "01",
    # "02",
    # "03",
    # "04",
    "05",
    # "06",
    # "07",
    # "08",
    # "09",
    # "10",
    # "11",
    # "12",
    # "13",
    # "14",
    # "15",
    # "16",
    # "17",
    # "18",
    # "19",
    # "20",
    # "21",
    # "22",
    # "23",
    # "24",
    # "25",
    # "27",
    # "28",
    # "29",
    # "30",
    # "32",
    # "33",
    "34",
    # "35",
    # "36",
    # "38",
    # "39",
    # "40",
    # "41",
    # "42",
    # "43",
    # "44",
    # "45",
    # "46",
]

# Folder the movies and annotation masks are loaded from (relative to the
# parent of the current working directory).
raw_data_dir = os.path.join("..", "data", "sparks_dataset")
# Output folder, created in the current working directory.
# exist_ok=True makes re-running this cell harmless.
out_dir = "patch_dataset"
os.makedirs(out_dir, exist_ok=True)

# Patch size along each axis, in (t, y, x) order.
patch_shape = (64, 32, 32)

In [4]:
def fits_in_patch(spark, patch_start, patch_shape):
    """
    Tell whether a spark peak lies inside a patch.

    Along every axis the patch covers the half-open interval
    [start, start + size): the start coordinate is included, the end
    coordinate is excluded.

    Parameters:
    spark (tuple): The (t, y, x) coordinates of the spark.
    patch_start (tuple): The (t_start, y_start, x_start) coordinates of the
        first voxel of the patch.
    patch_shape (tuple): The size (t_patch, y_patch, x_patch) of the patch.

    Returns:
    bool: True if the spark is inside the patch along all three axes,
        False otherwise.
    """
    size_t, size_y, size_x = patch_shape
    pos_t, pos_y, pos_x = spark
    first_t, first_y, first_x = patch_start

    # One containment check per axis (upper bound exclusive).
    inside_t = first_t <= pos_t < first_t + size_t
    inside_y = first_y <= pos_y < first_y + size_y
    inside_x = first_x <= pos_x < first_x + size_x

    return inside_t and inside_y and inside_x


def create_patches_from_sparks(sparks_coord, video_shape, patch_shape):
    """
    Create patches around spark peaks so that every spark lies in a patch.

    Sparks are visited in the given order. A spark that already lies inside
    a previously created patch is skipped; otherwise a new patch is centred
    on it and, if necessary, shifted so that it stays inside the video.
    Patches created this way may overlap, so a spark can end up inside more
    than one patch.

    Parameters:
    sparks_coord (iterable): The (t, y, x) coordinates of the sparks. Each
        item may be any 3-element sequence (tuple, list, array row).
    video_shape (tuple): The shape (t, y, x) of the video.
    patch_shape (tuple): The shape (t_patch, y_patch, x_patch) of the patches.

    Returns:
    list: A list of patch start coordinates (t_start, y_start, x_start), in
        the order in which the patches were created. Start coordinates are
        never negative, even when the video is smaller than the patch along
        some axis (the patch then starts at 0 and extends past the video).
    """
    t_patch, y_patch, x_patch = patch_shape
    t_video, y_video, x_video = video_shape[:3]
    patches = []

    def patch_origin(center, patch_len, video_len):
        """Return the start of a patch centred on `center`, kept in bounds."""
        start = max(0, center - patch_len // 2)
        if start + patch_len > video_len:
            # Shift the patch back so it ends at the video border. The
            # max() guards against a negative start when the video is
            # shorter than the patch along this axis.
            start = max(0, video_len - patch_len)
        return start

    def is_covered(spark):
        """Return True if an already created patch contains the spark."""
        t, y, x = spark
        return any(
            t_start <= t < t_start + t_patch
            and y_start <= y < y_start + y_patch
            and x_start <= x < x_start + x_patch
            for t_start, y_start, x_start in patches
        )

    for spark in sparks_coord:
        # A spark is "assigned" exactly when some existing patch contains
        # it, so no separate bookkeeping of assigned sparks is needed.
        if is_covered(spark):
            continue

        t, y, x = spark
        patches.append(
            (
                patch_origin(t, t_patch, t_video),
                patch_origin(y, y_patch, y_video),
                patch_origin(x, x_patch, x_video),
            )
        )

    return patches

In [5]:
# Maps each sample ID to the list of patch start coordinates
# (t_start, y_start, x_start) that were kept for that sample.
patches_dict = {}

# Maximum number of puff (or wave) pixels tolerated inside one patch.
# NOTE(review): the limit is 10% of the area of a single (y, x) frame of the
# patch, but it is compared with a pixel count taken over the whole 3D
# patch -- confirm that this is the intended threshold.
max_event_pixels = 0.1 * patch_shape[1] * patch_shape[2]

for sample_id in sample_ids:
    print(f"Extracting patches from sample {sample_id}...")

    video = load_movies_ids(
        data_folder=raw_data_dir,
        ids=[sample_id],
        names_available=True,
        movie_names="video",
    )[sample_id]
    print(f"  Loaded video with shape: {video.shape}")

    class_label = load_annotations_ids(
        data_folder=raw_data_dir, ids=[sample_id], mask_names="class_label"
    )[sample_id]
    # Report which event classes occur in this sample (0 is excluded).
    events_in_sample = np.unique(class_label)
    events_in_sample = events_in_sample[events_in_sample != 0]
    print(f"  Type of events in sample {sample_id}: {events_in_sample}")

    event_label = load_annotations_ids(
        data_folder=raw_data_dir, ids=[sample_id], mask_names="event_label"
    )[sample_id]

    # Keep the event IDs only where the class mask equals 1 (sparks).
    sparks_event_label = np.where(class_label == 1, event_label, 0)

    sparks_coord = detect_spark_peaks(
        movie=video,
        instances_mask=sparks_event_label,
        sigma=config.sparks_sigma_dataset,
        max_filter_size=10,
    )

    # Create patches from the spark coordinates
    patches = create_patches_from_sparks(
        sparks_coord=sparks_coord, video_shape=video.shape, patch_shape=patch_shape
    )

    # Filter out patches that contain too much puff or wave pixels.
    # The kept patches are collected in a new list: removing items from
    # `patches` while iterating over it would skip the element that follows
    # each removed one, leaving some patches unchecked.
    kept_patches = []
    for patch_start in patches:
        t_start, y_start, x_start = patch_start
        t_end = t_start + patch_shape[0]
        y_end = y_start + patch_shape[1]
        x_end = x_start + patch_shape[2]
        patch_classes = class_label[t_start:t_end, y_start:y_end, x_start:x_end]
        puff_pixels = np.sum(patch_classes == 2)
        wave_pixels = np.sum(patch_classes == 3)
        if puff_pixels > max_event_pixels or wave_pixels > max_event_pixels:
            print(f"  Removing patch {patch_start} due to puff or wave pixels")
        else:
            kept_patches.append(patch_start)
    patches = kept_patches

    # Output information about the patches
    for i, patch_start in enumerate(patches):
        t_start, y_start, x_start = patch_start
        t_end = t_start + patch_shape[0]
        y_end = y_start + patch_shape[1]
        x_end = x_start + patch_shape[2]
        print(f"  Patch {i+1}:")
        print(
            f"    T: ({t_start}, {t_end}), Y: ({y_start}, {y_end}), X: ({x_start}, {x_end})"
        )
        # Optional verbose output (uncomment for debugging):
        # peaks_in_patch = [
        #     spark
        #     for spark in sparks_coord
        #     if fits_in_patch(spark, patch_start, patch_shape)
        # ]
        # classes_in_patch = np.unique(
        #     class_label[t_start:t_end, y_start:y_end, x_start:x_end]
        # )
        # classes_in_patch = classes_in_patch[classes_in_patch != 0]
        # sparks_in_patch = np.unique(
        #     sparks_event_label[t_start:t_end, y_start:y_end, x_start:x_end]
        # )
        # sparks_in_patch = sparks_in_patch[sparks_in_patch != 0]
        # print(f"    Classes in patch: {classes_in_patch}")
        # print(f"    Sparks in patch: {sparks_in_patch}")
        # print(f"    Peaks in patch:")
        # for p in peaks_in_patch:
        #     print(f"      {p} (event ID: {event_label[p]})")

    patches_dict[sample_id] = patches

Extracting patches from sample 01...
  Loaded video with shape: (500, 64, 512)
  Type of events in sample 01: [1 3 4]
  Removing patch (37, 5, 187) due to puff or wave pixels
  Removing patch (198, 3, 195) due to puff or wave pixels
  Removing patch (286, 0, 192) due to puff or wave pixels
  Removing patch (243, 0, 213) due to puff or wave pixels
  Patch 1:
    T: (48, 112), Y: (30, 62), X: (227, 259)
  Patch 2:
    T: (143, 207), Y: (29, 61), X: (246, 278)
  Patch 3:
    T: (336, 400), Y: (8, 40), X: (219, 251)
  Patch 4:
    T: (1, 65), Y: (10, 42), X: (166, 198)
  Patch 5:
    T: (163, 227), Y: (9, 41), X: (181, 213)
  Patch 6:
    T: (202, 266), Y: (19, 51), X: (146, 178)
  Patch 7:
    T: (436, 500), Y: (23, 55), X: (171, 203)
  Patch 8:
    T: (33, 97), Y: (10, 42), X: (166, 198)
  Patch 9:
    T: (29, 93), Y: (0, 32), X: (212, 244)
Extracting patches from sample 05...
  Loaded video with shape: (500, 64, 512)
  Type of events in sample 05: [1 3 4]


TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'


  Removing patch (75, 0, 102) due to puff or wave pixels
  Removing patch (47, 0, 178) due to puff or wave pixels
  Removing patch (322, 29, 147) due to puff or wave pixels
  Removing patch (164, 6, 182) due to puff or wave pixels
  Removing patch (374, 32, 154) due to puff or wave pixels
  Removing patch (6, 31, 51) due to puff or wave pixels
  Removing patch (196, 32, 154) due to puff or wave pixels
  Removing patch (271, 5, 122) due to puff or wave pixels
  Removing patch (92, 10, 69) due to puff or wave pixels
  Removing patch (9, 22, 73) due to puff or wave pixels
  Removing patch (267, 20, 187) due to puff or wave pixels
  Patch 1:
    T: (164, 228), Y: (25, 57), X: (287, 319)
  Patch 2:
    T: (154, 218), Y: (32, 64), X: (126, 158)
  Patch 3:
    T: (26, 90), Y: (9, 41), X: (92, 124)
  Patch 4:
    T: (45, 109), Y: (25, 57), X: (86, 118)
  Patch 5:
    T: (251, 315), Y: (4, 36), X: (320, 352)
  Patch 6:
    T: (63, 127), Y: (32, 64), X: (127, 159)
  Patch 7:
    T: (339, 403), Y

In [6]:
# Total number of patches kept across all processed samples.
n_patches = sum(map(len, patches_dict.values()))
print(f"Total number of patches: {n_patches}")

Total number of patches: 58


In [13]:
# Visualize selected patches using napari
# This cell reloads the data of one sample and recomputes its spark peaks;
# the patches themselves are taken from `patches_dict`.

sample_id = "34"

# Raw recording of the chosen sample.
video = load_movies_ids(
    data_folder=raw_data_dir,
    ids=[sample_id],
    names_available=True,
    movie_names="video",
)[sample_id]

# Annotation masks, loaded in the order: class labels, then event labels.
class_label, event_label = (
    load_annotations_ids(
        data_folder=raw_data_dir, ids=[sample_id], mask_names=mask_name
    )[sample_id]
    for mask_name in ("class_label", "event_label")
)

# Event IDs restricted to voxels whose class label is 1 (sparks).
is_spark = class_label == 1
sparks_event_label = np.where(is_spark, event_label, 0)

sparks_coord = detect_spark_peaks(
    movie=video,
    instances_mask=sparks_event_label,
    sigma=config.sparks_sigma_dataset,
    max_filter_size=10,
)

patches = patches_dict[sample_id]
print(f"Number of patches for sample {sample_id}: {len(patches)}")

TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'


Number of patches for sample 34: 13


In [14]:
# Open one napari viewer per patch, showing the cropped video, the cropped
# class mask and the spark peaks that fall inside the patch.
for patch_id, (t_start, y_start, x_start) in enumerate(patches):

    # patch_id = 45 + patch_id

    t_patch, y_patch, x_patch = patch_shape
    t_end = t_start + t_patch
    y_end = y_start + y_patch
    x_end = x_start + x_patch

    # Index selecting the patch out of a full-size (t, y, x) array.
    patch_region = (
        slice(t_start, t_end),
        slice(y_start, y_end),
        slice(x_start, x_end),
    )

    # Peaks inside this patch, expressed relative to the patch origin.
    peaks_in_patch = [
        (t - t_start, y - y_start, x - x_start)
        for (t, y, x) in sparks_coord
        if fits_in_patch((t, y, x), (t_start, y_start, x_start), patch_shape)
    ]

    viewer = napari.Viewer()
    viewer.add_image(video[patch_region], name="video")
    viewer.add_labels(
        class_label[patch_region],
        name="class_label",
        opacity=0.5,
    )
    # Optional extra layers (uncomment to display):
    # viewer.add_labels(
    #     event_label[t_start:t_end, y_start:y_end, x_start:x_end],
    #     name="event_label",
    #     opacity=0.5,
    #     visible=False
    #     )
    # viewer.add_labels(
    #     sparks_event_label[t_start:t_end],
    #     name="sparks_event_label",
    #     opacity=0.5,
    #     visible=False
    #     )
    viewer.add_points(peaks_in_patch, name="peaks_in_patch", face_color="red", size=2)
    # Start at the first frame, with the other two sliders at the patch centre.
    viewer.dims.current_step = (0, y_patch // 2, x_patch // 2)