In [1]:
import btrack
import zarr
import os
import napari
from macrohet import dataio
import glob
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
import h5py

!nvcc --version
!nvidia-smi

from cellpose import core, utils, io, models, metrics
use_GPU = core.use_gpu()
yn = ['NO', 'YES']
print(f'>>> GPU activated? {yn[use_GPU]}')
cellpose_model = models.Cellpose(gpu=True, model_type='cyto')

device = "cuda" if torch.cuda.is_available() else "cpu"
# Load a pretrained model
trackastra_model = Trackastra.from_pretrained("general_2d", device=device)

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0
Fri Oct 25 14:41:07 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:65:00.0  On |                  Off |
| 30%   42C    P8             35W /  300W |   48403MiB /  49140MiB |     20%      Default |
|                      

INFO:cellpose.core:Neither TORCH CUDA nor MPS version not installed/working.
INFO:cellpose.core:Neither TORCH CUDA nor MPS version not installed/working.


>>> GPU activated? NO


INFO:cellpose.core:>>>> using CPU
INFO:cellpose.core:>>>> using CPU
INFO:cellpose.models:>> cyto << model set to be used
INFO:cellpose.core:see https://pytorch.org/docs/stable/backends.html?highlight=mkl
INFO:cellpose.models:>>>> loading model /home/dayn/.cellpose/models/cytotorch_0
INFO:cellpose.models:>>>> model diam_mean =  30.000 (ROIs rescaled to this size during training)
INFO:trackastra.model.model:Loading model state from /home/dayn/.trackastra/.models/general_2d/model.pt
INFO:trackastra.model.model_api:Using device cuda


/home/dayn/.trackastra/.models/general_2d already downloaded, skipping.


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [2]:
expt_ID = 'ND0004'
base_dir = f'/mnt/SYNO/macrohet_syno/data/{expt_ID}/'

# Image metadata: the first "Index*xml" file glob returns under acquisition/Images
# (glob order is filesystem-dependent).
index_xml_candidates = glob.glob(os.path.join(base_dir, 'acquisition/Images/Index*xml'))
metadata_fn = index_xml_candidates[0]
metadata = dataio.read_harmony_metadata(metadata_fn)

# Plate layout: the first XML under acquisition/Assaylayout, read with replicate numbers.
layout_xml_candidates = glob.glob(os.path.join(base_dir, 'acquisition/Assaylayout/*.xml'))
metadata_path = layout_xml_candidates[0]
assay_layout = dataio.read_harmony_metadata(
    metadata_path,
    assay_layout=True,
    replicate_number=True,
)

Reading metadata XML file...


0it [00:00, ?it/s]

Extracting metadata complete!
Reading metadata XML file...
Extracting metadata complete!


In [13]:
viewer = napari.Viewer(title='testing nd4 segmentation and tracking')

# Channels are split along axis 1: the first is shown in magenta, the second in green.
channel_colormaps = ['magenta', 'green']
channel_contrast = [[0, 1000], [0, 2400]]
viewer.add_image(
    images,
    channel_axis=1,
    colormap=channel_colormaps,
    blending='additive',
    contrast_limits=channel_contrast,
)
# NOTE(review): `segmentation` is not defined in any visible cell; this relies on
# kernel state from elsewhere. Confirm which label stack is intended.
viewer.add_labels(segmentation)

<Labels layer 'segmentation' at 0x7f3e049b6e90>

In [13]:
scale = (1, 1, 1)  # per-axis coordinate multipliers; all ones = no rescaling

# Uniting z3 Cellpose segmentation and Trackastra tracking

In [None]:
# Analysis thresholds
segment_size_thresh = 5000  # NOTE(review): not referenced in the visible cells
mtb_load_thresh = 480  # max-projected Mtb-channel intensity at/above which a pixel counts as Mtb

# Channel indices along axis 1 of the image stacks
mtb_channel_ID, mphi_channel_ID = 0, 1

def calculate_msd(x, y):
    """Per-time-point step displacement along a 2D trajectory.

    Despite the name, this is NOT a mean squared displacement. It returns the
    Euclidean distance between each point and the previous one,
    ``sqrt(dx**2 + dy**2)``. The first element is 0 because the first point is
    compared with itself.

    Parameters
    ----------
    x, y : array-like
        Coordinates of consecutive time points, of equal length. Positional
        order is used, so pandas Series with arbitrary indexes are fine.

    Returns
    -------
    numpy.ndarray
        Step lengths, the same length as ``x``. Empty if ``x`` is empty.
    """
    # Convert to arrays so that x[0] is positional, not a pandas label lookup.
    x = np.asarray(x)
    y = np.asarray(y)
    if x.size == 0:
        return np.zeros(0)
    # Prepending the first value makes the first difference 0 and keeps the length.
    dx = np.diff(x, prepend=x[0])
    dy = np.diff(y, prepend=y[0])
    return np.sqrt(dx**2 + dy**2)

# `image` must be indexable by the boolean mask built from `segmentation`
# (e.g. the same 2D shape).
def measure_segment_intensity(image, segmentation, ID):
    """Mean pixel value of ``image`` inside the segment labelled ``ID``.

    If ``ID`` does not occur in ``segmentation`` the selection is empty, and the
    mean is NaN (NumPy also emits a RuntimeWarning).
    """
    in_segment = segmentation == ID
    return np.mean(image[in_segment])


expt_ID = 'ND0004'
base_dir = f'/mnt/SYNO/macrohet_syno/data/{expt_ID}/'

# Harmony image metadata and plate (assay) layout for this experiment.
metadata_fn = glob.glob(os.path.join(base_dir, 'acquisition/Images/Index*xml'))[0]
metadata = dataio.read_harmony_metadata(metadata_fn)
metadata_path = glob.glob(os.path.join(base_dir, 'acquisition/Assaylayout/*.xml'))[0]
assay_layout = dataio.read_harmony_metadata(metadata_path,
                                            assay_layout=True,
                                            replicate_number=True)

# Pixel-count -> µm² conversion. ImageResolutionX is presumably metres per pixel
# (confirm against the Harmony export), so resolution**2 / (1e-6)**2 = µm² per pixel.
image_resolution = float(metadata['ImageResolutionX'].iloc[0])
meters_area_per_pixel = image_resolution ** 2
mum_sq_scale_factor = (1E-6) ** 2
pixel_to_mum_sq_scale_factor = meters_area_per_pixel / mum_sq_scale_factor

# Processing order: by concentration, then compound, then strain.
# Values not listed sort after all listed ones.
compound_order = ['RIF', 'PZA', 'INH', 'CTRL', 'BDQ']
concentration_order = ['EC99', 'EC50', 'EC0']
compound_rank = {name: rank for rank, name in enumerate(compound_order)}
concentration_rank = {name: rank for rank, name in enumerate(concentration_order)}
strain_rank = {'WT': 0, 'RD1': 1}

assay_layout['compound_sort'] = assay_layout['Compound'].apply(
    lambda c: compound_rank.get(c, len(compound_order)))
assay_layout['concentration_sort'] = assay_layout['ConcentrationEC'].apply(
    lambda c: concentration_rank.get(c, len(concentration_order)))
assay_layout['strain_sort'] = assay_layout['Strain'].apply(
    lambda s: strain_rank.get(s, 2))

assay_layout_sorted = assay_layout.sort_values(
    by=['concentration_sort', 'compound_sort', 'strain_sort'])

# Index entries of the layout (presumably (row, column) well tuples) in that order.
row_col_order = list(assay_layout_sorted.index)

In [None]:
# Minimum number of time points a Trackastra track must span to be kept.
min_track_length = 75

for acq_ID in tqdm(row_col_order, total = len(row_col_order), desc = 'Iterating over individual wells'):
    output_fn = f'/mnt/SYNO/macrohet_syno/data/{expt_ID}/labels/sc_df_{acq_ID[0]}.{acq_ID[1]}.{expt_ID}.pkl'

    # Resume support: skip wells whose per-cell DataFrame already exists.
    if os.path.exists(output_fn):
        print(f'skipping acq ID {acq_ID}')
        continue

    # NOTE(review): well (4, 7) is excluded without an explanation; document why.
    if acq_ID == (4, 7):
        continue

    # Well-level annotations from the assay layout.
    # NOTE(review): these are not yet written into the output DataFrame.
    well_info = assay_layout.loc[(acq_ID[0], acq_ID[1])]
    technical_replicate = well_info['Replicate #']
    biological_replicate = 4
    strain = well_info['Strain']
    compound = well_info['Compound']  # previously read from 'Strain' by mistake
    concentration = well_info['Concentration']

    # Load the full image stack for this well.
    image_dir = os.path.join(base_dir, f'acquisition/zarr/{acq_ID}.zarr')
    zarr_group = zarr.open(image_dir, mode='r')
    images = zarr_group.images[...]

    # Load the original Cellpose segmentation; the area measurements below come from it.
    with btrack.io.HDF5FileHandler(os.path.join(f'/mnt/SYNO/macrohet_syno/data/{expt_ID}/labels/cpv3/{acq_ID}.h5'),
                                   'r',
                                   obj_type='obj_type_1'
                                   ) as reader:
        original_segmentation = reader.segmentation

    # Re-segment every frame with Cellpose using channel 1 at z-index 2 ("z3").
    # These masks are what Trackastra tracks.
    segmentation_input = images[:, 1, 2, ...]
    mask_stack = []
    for frame in tqdm(segmentation_input, total = len(segmentation_input)):
        masks, flows, styles, diams = cellpose_model.eval(frame,
                                                          diameter=150,
                                                          channels=[0,0],)
        mask_stack.append(masks)
    tracking_input_segmentation = np.stack(mask_stack, axis = 0)

    # Fill blank (all-zero) frames so every time point has labels to track.
    non_blank = np.array([frame_masks.any() for frame_masks in tracking_input_segmentation])
    if not non_blank.any():
        print("Error: All frames are blank. No processing will be done.")
        # Nothing to track: abort the run, as before.
        raise ValueError("Segmentation data is entirely blank.")
    if not non_blank[0]:
        # argmax on a boolean array gives the first True, i.e. the first non-blank frame.
        j = int(np.argmax(non_blank))
        tracking_input_segmentation[0] = tracking_input_segmentation[j]
        print(f"First frame was blank. Copied from frame {j}.")
    # Forward-fill. Frame 0 is now non-blank and blank frames are filled in order,
    # so frame i-1 always has labels to copy.
    for i in range(1, tracking_input_segmentation.shape[0]):
        if not non_blank[i]:
            tracking_input_segmentation[i] = tracking_input_segmentation[i - 1]
            print(f"Frame {i} was blank. Copied from frame {i-1}.")

    # Save the new segmentation, compressed.
    with h5py.File(f'/mnt/SYNO/macrohet_syno/data/{expt_ID}/labels/z3_cpv3/{acq_ID}.h5', 'w') as f:
        f.create_dataset('segmentation', data=tracking_input_segmentation, compression="gzip")

    # Track the cells using the new segmentation (other modes: "ilp", "greedy").
    track_graph = trackastra_model.track(segmentation_input, tracking_input_segmentation, mode="greedy_nodiv")

    tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)
    # NOTE(review): napari 2D track rows are presumably (ID, t, row, col). So 'x'
    # here indexes image rows and 'y' indexes columns, which matches the [x, y]
    # lookups below.
    tracks = pd.DataFrame(tracks, columns=['ID', 't', 'x', 'y']).astype(int)
    tracks = tracks.groupby('ID').filter(lambda x: len(x) > min_track_length)

    # Max-project the Mtb channel over axis 1 (z) and threshold it.
    mtb_channel = np.max(images[:, mtb_channel_ID, ...], axis = 1)
    thresholded_mtb_channel = mtb_channel >= mtb_load_thresh

    track_dfs = []
    for cell_ID, track in tqdm(tracks.groupby('ID'), desc = 'iterating over tracks', total = tracks['ID'].nunique(), leave = False):
        track = track.sort_values(by='t')
        frames = track['t'].to_numpy()
        x_coords = track['x'].to_numpy().astype(int)
        y_coords = track['y'].to_numpy().astype(int)

        mtb_areas = []
        mphi_areas = []
        for frame, x, y in zip(frames, x_coords, y_coords):
            # `t` indexes the same frame axis as the stacks passed to Trackastra,
            # which is also the frame axis of `images`. The previous `frame - 1`
            # read the preceding frame, and for t=0 wrapped round to the last frame.
            segment_ID = original_segmentation[frame][x, y]
            if segment_ID == 0:  # background: no object under the track point
                mtb_areas.append(np.nan)
                mphi_areas.append(np.nan)
                continue
            mask = original_segmentation[frame] == segment_ID
            # Mtb area = thresholded Mtb pixels inside the segment; Mphi area = segment size.
            mtb_areas.append(np.sum(thresholded_mtb_channel[frame][mask]))
            mphi_areas.append(np.sum(mask))

        # NOTE: the scale factor converts pixels to µm², despite the "(µm)" column
        # labels. The labels are kept so downstream readers still work.
        track['Mtb Area (µm)'] = np.array(mtb_areas) * pixel_to_mum_sq_scale_factor
        track['Mphi Area (µm)'] = np.array(mphi_areas) * pixel_to_mum_sq_scale_factor

        # Infection statuses (NaN areas compare as False, i.e. not infected).
        track['Infection Status'] = track['Mtb Area (µm)'] > 0
        track['Initial Infection Status'] = track['Mtb Area (µm)'].iloc[0] > 0
        track['Final Infection Status'] = track['Mtb Area (µm)'].iloc[-1] > 0
        track['ID'] = cell_ID
        track['Unique_ID'] = f'{cell_ID}.{acq_ID[0]}.{acq_ID[1]}.{expt_ID}'

        track_dfs.append(track)

    # pd.concat([]) raises, so report wells with no surviving tracks instead of crashing.
    if not track_dfs:
        print(f'No tracks longer than {min_track_length} frames for acq ID {acq_ID}; nothing saved.')
        continue

    df = pd.concat(track_dfs, ignore_index=True)
    df.to_pickle(output_fn)

# Checking segmentation matching


In [15]:
color_by = "ID"  # track property intended for recolouring the segmentation

In [62]:
### checking segmentation matching
# Paint each tracked object's segment with its track ID, so the relabelled stack
# can be compared against the original segmentation.
# NOTE(review): this cell expects `tracks` with upper-case 'T', 'X', 'Y' columns.
# The well-processing loop produces lower-case 't', 'x', 'y'. It also expects a
# per-frame 2D `segmentation` that is not defined in any visible cell. Confirm
# which objects are intended, and whether 'X' is the row axis (the lookup below
# indexes [X, Y]).

# Initialize relabeled array
relabeled = np.zeros_like(segmentation)
n_frames = segmentation.shape[0]

# Iterate over each track, grouped by 'ID'
for cell_ID, track in tqdm(tracks.groupby('ID'), desc = 'Iterating over tracks', total = len(tracks.ID.unique())):
    track = track.sort_values(by='T')
    times = track['T'].to_numpy().astype(int)  # frame index, used as-is (no halving)
    x_coords = (track['X'].to_numpy() * scale[0]).astype(int)
    y_coords = (track['Y'].to_numpy() * scale[1]).astype(int)
    new_id = int(cell_ID)

    for t, x, y in zip(times, x_coords, y_coords):
        # Skip points outside the stack. A negative index would otherwise wrap
        # round silently, and an index that is too large would raise IndexError.
        if not (0 <= t < n_frames):
            continue
        frame_labels = segmentation[t]
        if not (0 <= x < frame_labels.shape[0] and 0 <= y < frame_labels.shape[1]):
            continue

        old_id = frame_labels[x, y]
        if old_id == 0:  # background: nothing to relabel
            continue
        # Recolour every pixel of the segment under the track point with the track ID.
        relabeled[t][frame_labels == old_id] = new_id

Iterating over tracks:   0%|          | 0/162 [00:00<?, ?it/s]