This notebook maps each label in the tracked (relabeled) segmentations back to the matching label in the registered segmentations, so that the original segmentation labels can be used for training Trackastra.

In [1]:
from dataclasses import dataclass
import glob
import os
import pickle

import networkx as nx
import numpy as np
from scipy import stats
from skimage.io import imread

In [2]:
def load_images(
    directory_path: os.PathLike,
    name_pattern: str
) -> np.ndarray:
    """Load a numbered series of images from a directory into one array.

    The number of time points is taken to be the number of ``*.tif`` files
    in ``directory_path``. Time point ``t`` is then read from
    ``name_pattern.format(t)`` for ``t = 0 .. n_files - 1``, so the directory
    is expected to hold exactly one series, numbered consecutively from zero.

    Parameters
    ----------
    directory_path : os.PathLike
        Directory containing the image files.
    name_pattern : str
        File name template with one ``{}`` placeholder that is filled with
        the time index, e.g. ``"embryo3_t_{}_seg.tif"``.

    Returns
    -------
    np.ndarray
        The images stacked along a new first axis, in time-index order.

    Raises
    ------
    ValueError
        If ``directory_path`` contains no ``.tif`` files (this also covers a
        directory that does not exist).
    """
    n_files = len(glob.glob(os.path.join(directory_path, "*.tif")))
    if n_files == 0:
        # Without this guard np.stack([]) fails with an opaque
        # "need at least one array to stack" message.
        raise ValueError(f"No .tif files found in {directory_path!r}")

    images = [
        imread(os.path.join(directory_path, name_pattern.format(t_index)))
        for t_index in range(n_files)
    ]
    return np.stack(images)

In [3]:
@dataclass
class DataToConvert:
    """Input locations and naming information for one dataset to convert."""
    # Directory read with the file name pattern `<base_name>_t_{}_tracked.tif`.
    tracked_segmentation_directory: str
    # Directory read with the file name pattern `<base_name>_t_{}_seg.tif`.
    registered_segmentation_directory: str
    # Path to the pickled graph whose nodes carry "opticell_label" and "t".
    graph_path: str
    # Prefix of the image file names; also the prefix of the output pickle name.
    base_name: str

In [4]:
# Datasets to process, one entry per embryo.
# All paths are absolute (/local1, /nas), so they must be reachable from the
# machine that runs this notebook.
all_datasets = [
    # NOTE: the embryo3 paths do not follow the
    # process_all/{track,registered}_dual/embryoN layout used by the entries below.
    DataToConvert(
        tracked_segmentation_directory = "/local1/early_embryo/mzg_20240424/track_pre_processed/relabeled_segmentation",
        registered_segmentation_directory = "/local1/early_embryo/mzg_20240424/registered",
        graph_path = "/nas/groups/iber/Projects/Embryo_parameter_estimation/old/process_all_opticell3d_20240812/track_old/embryo_3/curated_graph.pkl",
        base_name = "embryo3"
    ),
    DataToConvert(
        tracked_segmentation_directory = "/local1/early_embryo/mzg_20240424/process_all/track_dual/embryo4/relabeled_segmentation",
        registered_segmentation_directory = "/local1/early_embryo/mzg_20240424/process_all/registered_dual/embryo4",
        graph_path = "/local1/early_embryo/mzg_20240424/process_all/track_dual/embryo4/curated_graph.pkl",
        base_name = "embryo4"
    ),
    DataToConvert(
        tracked_segmentation_directory = "/local1/early_embryo/mzg_20240424/process_all/track_dual/embryo5/relabeled_segmentation",
        registered_segmentation_directory = "/local1/early_embryo/mzg_20240424/process_all/registered_dual/embryo5",
        graph_path = "/local1/early_embryo/mzg_20240424/process_all/track_dual/embryo5/curated_graph.pkl",
        base_name = "embryo5"
    ),
    DataToConvert(
        tracked_segmentation_directory = "/local1/early_embryo/mzg_20240424/process_all/track_dual/embryo6/relabeled_segmentation",
        registered_segmentation_directory = "/local1/early_embryo/mzg_20240424/process_all/registered_dual/embryo6",
        graph_path = "/local1/early_embryo/mzg_20240424/process_all/track_dual/embryo6/curated_graph.pkl",
        base_name = "embryo6"
    ),
]

In [5]:
# Minimum fraction of a tracked label's voxels that must fall inside a single
# registered label; below this a "doesn't match" message is printed.
MIN_OVERLAP_FRACTION = 0.95

for dataset in all_datasets:
    base_name = dataset.base_name

    # load the images; both stacks are indexed by time point along axis 0
    tracked_segmentations = load_images(
        directory_path=dataset.tracked_segmentation_directory,
        name_pattern=base_name + "_t_{}_tracked.tif"
    )
    registered_segmentations = load_images(
        directory_path=dataset.registered_segmentation_directory,
        name_pattern=base_name + "_t_{}_seg.tif"
    )
    # The tracked mask is used to index the registered image, so the
    # per-time-point shapes have to agree.
    if tracked_segmentations.shape[1:] != registered_segmentations.shape[1:]:
        raise ValueError(
            f"{base_name}: tracked images {tracked_segmentations.shape[1:]} and "
            f"registered images {registered_segmentations.shape[1:]} differ in shape"
        )

    # load the graph
    # NOTE: pickle.load can execute arbitrary code; only open graph files
    # that come from a trusted source.
    with open(dataset.graph_path, "rb") as f:
        graph = pickle.load(f)

    # iterate through the nodes and find, for each tracked label, the
    # registered label it overlaps with the most
    original_labels = {}
    for node, data in graph.nodes(data=True):
        tracked_label = data["opticell_label"]
        time_index = int(data["t"])

        tracked_time_point = tracked_segmentations[time_index, ...]
        registered_time_point = registered_segmentations[time_index, ...]

        registered_labels = registered_time_point[tracked_time_point == tracked_label]

        if registered_labels.size == 0:
            # The tracked label has no voxels at this time point, so there is
            # nothing to match. Record None instead of a silent NaN.
            print(f"{base_name} t {time_index}, label {tracked_label} not found in tracked segmentation")
            original_labels[node] = None
            continue

        # get the most common label; np.unique returns sorted values, so a
        # tie resolves to the smallest label value
        label_values, label_counts = np.unique(registered_labels, return_counts=True)
        best_index = int(np.argmax(label_counts))
        fraction_overlap = float(label_counts[best_index] / registered_labels.size)
        if fraction_overlap < MIN_OVERLAP_FRACTION:
            print(f"{base_name} t {time_index}, label {tracked_label} doesn't match: {fraction_overlap}")

        # .item() stores a native Python scalar rather than a NumPy type
        original_labels[node] = label_values[best_index].item()

    # update and save the graph (written to the current working directory)
    nx.set_node_attributes(graph, original_labels, name="original_label")
    with open(f"{base_name}_curated_graph.pkl", "wb") as f:
        pickle.dump(graph, f)

embryo3 t 271, label 11 doesn't match: 0.9297030432539974
embryo4 t 65, label 2 doesn't match: 0.8636196040614896
embryo5 t 55, label 2 doesn't match: 0.8684102709266033
embryo5 t 261, label 4 doesn't match: 0.7172352818805908
embryo5 t 262, label 4 doesn't match: 0.6803183977404031
embryo5 t 263, label 4 doesn't match: 0.6173089233132761
embryo5 t 264, label 4 doesn't match: 0.6427324088341038
embryo5 t 265, label 4 doesn't match: 0.7205934839137307
embryo5 t 268, label 4 doesn't match: 0.8591757770910696
embryo5 t 271, label 4 doesn't match: 0.7264249231316545
embryo5 t 232, label 4 doesn't match: 0.7267983600447261
embryo5 t 255, label 12 doesn't match: 0.8530189062817646
embryo5 t 260, label 12 doesn't match: 0.9332423609996254
embryo5 t 270, label 15 doesn't match: 0.8365636927480916
embryo6 t 250, label 7 doesn't match: 0.7433085184974711
embryo6 t 240, label 8 doesn't match: 0.8201334850675115
embryo6 t 273, label 8 doesn't match: 0.6648003445517751
