### Tracking with btrack

In [7]:
### Imports
# Standard library imports
import json
import os
import shutil
import logging
import sys

# Third-party imports
# Data handling
import numpy as np
import pandas as pd

# Image I/O and processing
import tifffile as tiff
from nd2reader import ND2Reader
from skimage.morphology import remove_small_objects

# Deep learning and segmentation
from csbdeep.utils import Path, normalize
from stardist import (
    fill_label_holes,
    random_label_cmap,
    calculate_extents,
    gputools_available,
)
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D
from tensorflow.keras.utils import Sequence

# Tracking
import btrack
from btrack.constants import BayesianUpdates

# Visualization
import matplotlib.cm as cm
import matplotlib.pyplot as plt

# Utilities
from tqdm import tqdm

from datetime import datetime

In [8]:
## Variables
## Directory Paths
# Input
IMG_DIR = '/mnt/imaging.data/PertzLab/apoDetection/TIFFs'
APO_DIR = '/mnt/imaging.data/PertzLab/apoDetection/ApoptosisAnnotation'
EXPERIMENT_INFO = '/mnt/imaging.data/PertzLab/apoDetection/List of the experiments.csv'
MASK_DIR = '../data/apo_masks'    # Stardist label predictions

# Output
TRACKED_MASK_DIR = '../data/tracked_masks'
CSV_DIR = '../data/apo_match_csv'    # File with manual and stardist centroids
DF_DIR = '../data/summary_dfs'
TRACK_DF_DIR = '../data/track_dfs'
CROPS_DIR = '../data/apo_crops_test'    # Directory with .tif files for QC
WINDOWS_DIR = '/home/nbahou/myimaging/apoDet/data/windows_test'    # Directory with crops for scDINO
RANDOM_DIR = os.path.join(WINDOWS_DIR, 'random')
CLASS_DCT_PATH = './extras/class_dicts'


## Processing Configuration
COMPARE_2D_VERS = True
SAVE_MASKS = True
LOAD_MASKS = True
USE_GPU = True
MIN_NUC_SIZE = 200

## Tracking Parameters
BT_CONFIG_FILE = "extras/cell_config.json"  # Path to btrack config file
EPS_TRACK = 70         # Tracking radius [px]
TRK_MIN_LEN = 25       # Minimum track length [frames]

#
MAX_TRACKING_DURATION = 20    # In minutes
FRAME_INTERVAL = 5    # minutes between images we want

WINDOW_SIZE = 61


## Logger Set Up
#logging.shutdown()    # For jupyter notebooks
logger = logging.getLogger(__name__)
#if logger.hasHandlers():
#    logger.handlers.clear()
# Get the current timestamp
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Define log directory and ensure it exists
log_dir = "./logs"  # Folder for logs
os.makedirs(log_dir, exist_ok=True)  # Create directory if it doesn't exist

log_filename = f"tracking_Btrack_{timestamp}.log"
log_path = os.path.join(log_dir, log_filename)

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_path),
        logging.StreamHandler(sys.stdout)  # Outputs to console too
    ],
    force = True
)

# Create a logger instance
logger = logging.getLogger(__name__)

# Only forward Warnings/Errors/Critical from btrack
logging.getLogger('btrack').setLevel(logging.WARNING)


In [9]:
def load_image_stack(path):
    """
    Load an image stack from a file based on its extension.
    
    Uses tifffile for TIFF files and ND2Reader for ND2 files.
    """
    if path.endswith(('.tif', '.tiff')):
        # Load TIFF file using tifffile
        return tiff.imread(path)
    elif path.endswith('.nd2'):
        # Load ND2 file using ND2Reader and convert it to a numpy array
        with ND2Reader(path) as nd2:
            return np.array(nd2)
    else:
        raise ValueError(f"Unsupported file format for file: {path}")

def get_image_paths(directory):
    """
    Returns a list of absolute paths of all TIFF and ND2 files in a directory.
    """
    valid_extensions = ('.tif', '.tiff', '.nd2')
    paths = [
        os.path.abspath(os.path.join(directory, f))
        for f in os.listdir(directory)
        if f.endswith(valid_extensions)
    ]
    return sorted(paths)

def crop_window(img, center_x, center_y, window_size):
    # Check if number is even, add one if so
    if window_size%2 == 0:
        window_size += 1
        logger.warning(f'\t\tWindow size even, adding 1. New window size: {window_size}')
    half_window_size = window_size // 2
    x_from = max(center_x - half_window_size, 0)
    x_to = min(center_x + half_window_size + 1, img.shape[1])
    y_from = max(center_y - half_window_size, 0)
    y_to = min(center_y + half_window_size + 1, img.shape[0])
    window = img[y_from:y_to, x_from:x_to]

    return window



def run_tracking(gt_filtered, fovX, fovY):
    logging.info("\tStarting tracking")
    btObj = btrack.utils.segmentation_to_objects(gt_filtered, properties=("area",), assign_class_ID=True)
    
    with btrack.BayesianTracker() as tracker:
        tracker.configure(BT_CONFIG_FILE)
        tracker.update_method = BayesianUpdates.APPROXIMATE
        tracker.max_search_radius = EPS_TRACK
        tracker.append(btObj)
        tracker.volume = ((0, fovX), (0, fovY))
        tracker.track(step_size=100)
        tracker.optimize()
        btTracks = tracker.tracks
    
    dfBTracks = pd.concat(pd.DataFrame(t.to_dict(["ID", "t", "x", "y"])) for t in btTracks)
    dfBTracks.rename(columns={"ID": "track_id", "t": "t", "x": "x", "y": "y", "class_id": "obj_id"}, inplace=True)
    dfBTracks["obj_id"] = dfBTracks["obj_id"].astype("Int32")
    logging.info("\t\tTracking Done.")
    return dfBTracks


def convert_obj_to_track_ids(gt_filtered, merged_df):
    """
    Converts object IDs to track IDs in Stardist masks using tracking data.

    Parameters:
        gt_filtered (np.ndarray): The 3D array (time, height, width) of segmentation masks.
        merged_df (pd.DataFrame): DataFrame containing tracking information with 'obj_id', 'track_id', and 't'.

    Returns:
        np.ndarray: New 3D mask where obj_ids are replaced with track_ids.
    """
    logger.info("\tConverting Obj_IDs to Track_IDs in stardist masks.")
    
    tracked_masks = np.zeros_like(gt_filtered)

    for t, mask_frame in enumerate(gt_filtered):
        current_df = merged_df[merged_df['t'] == t]

        # Create a mapping {obj_id: track_id} for this timepoint
        obj_to_track = current_df.set_index('obj_id')['track_id'].to_dict()

        # Replace obj_id in the mask with the corresponding track_id
        for obj_id, track_id in obj_to_track.items():
            tracked_masks[t][mask_frame == obj_id] = track_id

    return tracked_masks

In [10]:
# Load image paths in specified directory
logger.info("Starting Image Processing")
image_paths = get_image_paths(os.path.join(IMG_DIR))
filenames = [os.path.splitext(os.path.basename(path))[0] for path in image_paths[:2]]    ### TODO remove :2 here, was only for testing
logger.info(f"Detected {len(filenames)} files in specified directories.")
#print(filenames)

# Create directories for saving if they do not exist
output_dirs = [MASK_DIR, DF_DIR, TRACK_DF_DIR, TRACKED_MASK_DIR]
for path in output_dirs:
    os.makedirs(path, exist_ok=True)

# Loop over all files in target directory (predict labels, track and crop windows for each)
logger.info("Starting to process files.")
for path, filename in zip(image_paths, filenames):
    # Define and load labels
    mask_path = os.path.join(MASK_DIR, f'{filename}.npz')
    with np.load(mask_path) as data:
        gt_filtered = data['gt']  # Access the saved array
    
    df_path = os.path.join(DF_DIR, f'{filename}_pd_df.csv')
    strdst_df =  pd.read_csv(df_path, header=0)

    # Run tracking with Btrack
    _, fovY, fovX = gt_filtered.shape
    dfBTracks = run_tracking(gt_filtered, fovX, fovY)
    
    logger.info("\tMerging information from Btrack and stardist.")
    merged_df = strdst_df.merge(dfBTracks.drop(columns=["x", "y", "area"]), on=["obj_id", "t"], how="left")
    # Enable next line if you only want tracks which are longer than TRK_MIN_LEN
    # merged_df = merged_df[merged_df.groupby("track_id")["track_id"].transform('size') >= TRK_MIN_LEN].copy()
    logger.info("\t\tComplete.")

    # Save merged DataFrame to a CSV file
    merge_df_path = os.path.join(TRACK_DF_DIR, f"{filename}.csv")
    merged_df.to_csv(merge_df_path, index=False)
    logger.info(f"\tSaved merged dfs at: {merge_df_path}")

    ### Create a mask with btrack track_ids instead of stardists obj_ids
    tracked_masks = convert_obj_to_track_ids(gt_filtered, merged_df)
    
    # Save masks
    mask_path = os.path.join(TRACKED_MASK_DIR, f'{filename}.npz')
    np.savez_compressed(mask_path, gt=tracked_masks)
    logger.info(f"\t\tMask saved at: {mask_path}")
    

2025-03-19 16:52:33,674 - __main__ - INFO - Starting Image Processing
2025-03-19 16:52:33,697 - __main__ - INFO - Detected 2 files in specified directories.
2025-03-19 16:52:33,769 - __main__ - INFO - Starting to process files.
2025-03-19 16:52:37,523 - root - INFO - 	Starting tracking


100%|███████████████████████████████████████████████████████████████████████████████| 1441/1441 [01:37<00:00, 14.74it/s]


GLPK Integer Optimizer 5.0
29796 rows, 28305 columns, 42923 non-zeros
28305 integer variables, all of which are binary
Preprocessing...
14898 rows, 28305 columns, 42923 non-zeros
28305 integer variables, all of which are binary
Scaling...
 A: min|aij| =  1.000e+00  max|aij| =  1.000e+00  ratio =  1.000e+00
Problem data seem to be well scaled
Constructing initial basis...
Size of triangular part is 14898
Solving LP relaxation...
GLPK Simplex Optimizer 5.0
14898 rows, 28305 columns, 42923 non-zeros
*     0: obj =   8.451693242e+04 inf =   0.000e+00 (10678)
Perturbing LP to avoid stalling [511]...
*  8965: obj =   2.818301162e+04 inf =   2.000e-09 (1412) 5
Removing LP perturbation [10417]...
* 10417: obj =   2.708178225e+04 inf =   0.000e+00 (0) 4
OPTIMAL LP SOLUTION FOUND
Integer optimization begins...
Long-step dual simplex will be used
+ 10417: mip =     not found yet >=              -inf        (1; 0)
+ 10438: >>>>>   2.709167410e+04 >=   2.708690811e+04 < 0.1% (16; 0)
+ 10477: mip = 

100%|███████████████████████████████████████████████████████████████████████████████| 1441/1441 [01:33<00:00, 15.33it/s]


GLPK Integer Optimizer 5.0
22416 rows, 20738 columns, 31011 non-zeros
20738 integer variables, all of which are binary
Preprocessing...
11208 rows, 20738 columns, 31011 non-zeros
20738 integer variables, all of which are binary
Scaling...
 A: min|aij| =  1.000e+00  max|aij| =  1.000e+00  ratio =  1.000e+00
Problem data seem to be well scaled
Constructing initial basis...
Size of triangular part is 11208
Solving LP relaxation...
GLPK Simplex Optimizer 5.0
11208 rows, 20738 columns, 31011 non-zeros
*     0: obj =   5.898442483e+04 inf =   0.000e+00 (7433)
Perturbing LP to avoid stalling [557]...
Removing LP perturbation [7305]...
*  7305: obj =   1.977132116e+04 inf =   1.332e-15 (0) 5
OPTIMAL LP SOLUTION FOUND
Integer optimization begins...
Long-step dual simplex will be used
+  7305: mip =     not found yet >=              -inf        (1; 0)
+  7318: >>>>>   1.978664682e+04 >=   1.977801456e+04 < 0.1% (13; 0)
+  7362: mip =   1.978664682e+04 >=     tree is empty   0.0% (0; 45)
INTEGER 