In [1]:
import os
import cell_lineage_tracking as lineage
import gnn_ben_haim as bh_track
import tifffile
import matplotlib.pyplot as plt
from skimage.measure import find_contours
from skimage.draw import polygon
from matplotlib.collections import LineCollection
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
torch.set_default_dtype(torch.float32)
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import BaseTransform
import torch.nn as nn
from torch.utils.data import Subset
import lineage_detection

In [2]:
# Reporter gene -> experiment folder(s) whose tracked cell tables are loaded below.
# Dict order matters: it fixes the concatenation order of df_all and therefore node_id.
#
# Outstanding lineage-tracing work (candidates, not yet included):
#   - araB (giTG61 / CL08): CL008_giTG068_072925, positions 0-10 are CL08 -- no fluorescence signal seen.
#   - alkA (giTG63): DUMM_giTG63_giTG67_Glucose_121724_1 pos 5; DUMM_giTG068_063_061725 pos 1-8
#     (pos 0 exists but is unsuitable for lineage tracing).
#   - murQ (giTG67): DUMM_giTG068_067_061825, or DUMM_giTG63_giTG67_Glucose_121724_1 pos 6-9.
strain_exp_dict = dict(
    chpS=['DUMM_giTG62_Glucose_012925'],  # lineage tracing still needed (Nora working on it)
    baeS=['DUMM_giTG66_Glucose_012325'],  # lineage tracing is poor here; find a new experiment
    lacZ=['DUMM_giTG059_noKan_Glucose_031125'],
    gfcE=['DUMM_giTG064_Glucose_022625'],
    gldA=['DUMM_giTG69_Glucose_013025'],
    # murQ=['DUMM_giTG63_giTG67_Glucose_121724_1'],  # disabled
    alkA=['DUMM_giTG068_063_061725_v2', 'DUMM_giTG63_giTG67_Glucose_121724_1_v2'],
    mazF=['DUMM_giTG059_060_061125'],  # constitutive
    hupA=['DUMM_giTG068_052925', 'DUMM_giTG068_063_061725'],  # constitutive
)

In [3]:
# Load every experiment's per-cell table, tag rows with the reporter gene, and stack them.
DUMM_ANALYSIS_DIR = '/Users/noravivancogonzalez/Documents/DuMM_image_analysis'

cell_tables = []
for gene, exps in strain_exp_dict.items():
    for exp_view in exps:
        print(exp_view)
        all_cells_filename = f'{DUMM_ANALYSIS_DIR}/all_cell_data_{exp_view}.pkl'
        # NOTE: read_pickle can execute arbitrary code -- only load pickles produced by this pipeline.
        all_cells_pd = pd.read_pickle(all_cells_filename)
        all_cells_pd['gene'] = gene
        cell_tables.append(all_cells_pd)

# One concat at the end instead of one per file: concatenating inside the loop re-copies the
# growing frame every iteration (quadratic). ignore_index=True always yields a fresh 0..n-1
# index (also when only one file is loaded), which the next cell uses as node_id.
df_all = pd.concat(cell_tables, ignore_index=True) if cell_tables else pd.DataFrame()

DUMM_giTG62_Glucose_012925
DUMM_giTG66_Glucose_012325
DUMM_giTG059_noKan_Glucose_031125
DUMM_giTG064_Glucose_022625
DUMM_giTG69_Glucose_013025
DUMM_giTG068_063_061725_v2
DUMM_giTG63_giTG67_Glucose_121724_1_v2
DUMM_giTG059_060_061125
DUMM_giTG068_052925
DUMM_giTG068_063_061725


In [4]:
# node_id is carried through to the tracked output table as a per-cell identifier, so it
# must be unique; the row index is used for it.
assert df_all.index.is_unique, "df_all index must be unique to serve as node_id"
df_all['node_id'] = df_all.index

In [5]:
# Prefer Apple-silicon GPU (MPS) when present; otherwise run on CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [6]:
# Per-cell measurements used as GNN node features. The list order defines the column
# order of the feature matrix built in the next cell.
node_feature_cols = [
    'area', 'centroid_y',
    'axis_major_length', 'axis_minor_length',
    'intensity_mean_phase', 'intensity_max_phase', 'intensity_min_phase',
    'intensity_mean_fluor', 'intensity_max_fluor', 'intensity_min_fluor',
]
# Cast feature columns to float32 to match torch's default dtype set in the import cell.
df_all = df_all.astype({col: np.float32 for col in node_feature_cols})

In [7]:
# Access the node features from the single train_df DataFrame.
all_train_node_features_df = df_all[node_feature_cols].values.astype(np.float32)

# Initialize and fit the scaler on the training data.
scaler = StandardScaler()
scaler.fit(all_train_node_features_df)
print("Scaler fitted on training data.")

# Create an instance of the transform with the fitted scaler.
transform = bh_track.StandardScalerTransform(scaler)

dataset = bh_track.CellTrackingDataset(root='./processed_data_bh/full_dataset',
                                   df_cells=df_all,
                                   node_feature_cols=node_feature_cols,
                                   device=device,
                                   pre_transform=transform)

batch_size = 32
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

print(f"Number of test graphs: {len(dataset)}")

Scaler fitted on training data.
Starting data processing for processed_data_bh/full_dataset...


Processing...


Finished processing. Created 50 PyG Data objects.
Processed data saved to processed_data_bh/full_dataset/processed/cell_tracking_fovs.pt
Number of test graphs: 1


Done!


In [8]:
num_node_features = len(node_feature_cols)
# One edge channel per node feature plus one extra channel (its meaning is defined inside
# bh_track -- TODO confirm).
initial_edge_feature_dim = len(node_feature_cols) + 1
hidden_channels = 128  # must match the value the checkpoint was trained with

# Build the architecture with the same hyper-parameters as the saved checkpoint.
model = bh_track.LineageLinkPredictionGNN(
    in_channels=num_node_features,
    initial_edge_channels=initial_edge_feature_dim,
    hidden_channels=hidden_channels,  # previously hard-coded a second time as 128
    num_blocks=2,
).to(device)
# map_location lets a checkpoint saved on a different device (CUDA/MPS/CPU) load on this one.
model.load_state_dict(torch.load('best_link_prediction_model.pt', map_location=device))
# Inference mode: disables dropout / uses running statistics in any normalisation layers.
model.eval()
print("Loaded best model cell linkage.")

# No gradients are needed for prediction; this saves memory and time.
with torch.no_grad():
    lineage_predictions_df = bh_track.predict_cell_linkages(model,
                                                           test_loader,
                                                           device)

Loaded best model cell linkage.


In [9]:
# Probability cut-off passed to the lineage builder; presumably predicted links below it are
# discarded when assembling lineages -- confirm in lineage_detection.process_all_fovs.
link_prob_threshold = 0.8
df_consolidated_lineages = lineage_detection.process_all_fovs(
    lineage_predictions_df,
    df_all,
    prob_threshold=link_prob_threshold,
)

Column 'ground_truth_lineage' found. Including it in node attributes.
--- Processing FOV 1/50: DUMM_giTG62_Glucose_012925, 005, 445 ---
--- Adding nodes with all experimental data... ---
--- Building initial graph with edges... ---
Initial graph has 390 nodes and 304 edges.

All cycles have been removed.
Final graph has 390 nodes and 304 edges.
Successfully generated 304 lineage rows.
--- Processing FOV 2/50: DUMM_giTG62_Glucose_012925, 005, 900 ---
--- Adding nodes with all experimental data... ---
--- Building initial graph with edges... ---
Initial graph has 109 nodes and 96 edges.

All cycles have been removed.
Final graph has 109 nodes and 96 edges.
Successfully generated 96 lineage rows.
--- Processing FOV 3/50: DUMM_giTG62_Glucose_012925, 005, 1128 ---
--- Adding nodes with all experimental data... ---
--- Building initial graph with edges... ---
Initial graph has 269 nodes and 215 edges.

All cycles have been removed.
Final graph has 269 nodes and 215 edges.
Successfully genera

In [10]:
# Convert the lineage output back to a node-centric table (one row per cell). The preview
# below shows it keeps df_all's columns and adds `predicted_lineage`.
df_for_kymograph_plot = lineage_detection.consolidate_lineages_to_node_df(
    df_consolidated_lineages,
    df_all,
)

--- Consolidating lineage data to a node-centric format... ---
--- Consolidation complete. Final DataFrame is ready for plotting. ---


In [15]:
# Persist the tracked, node-centric table so the plotting section can be rerun after a
# kernel restart without recomputing predictions.
tracked_all_cells_filename = os.path.join(
    '/Users/noravivancogonzalez/Documents/DuMM_image_analysis',
    '20250829_tracked_all_cell_data.pkl',
)
df_for_kymograph_plot.to_pickle(tracked_all_cells_filename)

In [2]:
# Reload the tracked table written above (entry point after a kernel restart).
# NOTE: read_pickle can execute arbitrary code -- only load files produced by this pipeline.
tracked_all_cells_filename = os.path.join(
    '/Users/noravivancogonzalez/Documents/DuMM_image_analysis',
    '20250829_tracked_all_cell_data.pkl',
)
df_for_kymograph_plot = pd.read_pickle(tracked_all_cells_filename)

In [3]:
df_for_kymograph_plot.head()  # first 5 rows (the default)

Unnamed: 0,label,area,coords,centroid_y,centroid_x,axis_major_length,axis_minor_length,intensity_mean_phase,intensity_max_phase,intensity_min_phase,...,intensity_min_fluor,time_frame,experiment_name,FOV,trench_id,track_id,node_id,ground_truth_lineage,gene,predicted_lineage
0,1,182.0,"[[0, 166], [0, 167], [0, 168], [0, 169], [0, 1...",14.615385,169.423077,37.096642,6.667422,8683.736328,11092.0,6837.0,...,6.0,8.0,DUMM_giTG62_Glucose_012925,5,445,19,0,,chpS,5.0
1,3,233.0,"[[0, 226], [0, 227], [0, 228], [0, 229], [0, 2...",16.566525,230.412017,42.404186,7.669648,8531.793945,10206.0,6804.0,...,0.0,11.0,DUMM_giTG62_Glucose_012925,5,445,30,2,,chpS,11.0
2,5,165.0,"[[0, 325], [0, 326], [0, 327], [0, 328], [0, 3...",10.478787,328.666667,26.321367,8.304471,8684.884766,11109.0,6299.0,...,0.0,16.0,DUMM_giTG62_Glucose_012925,5,445,49,4,,chpS,4.1
3,7,198.0,"[[0, 2087], [0, 2088], [0, 2089], [1, 2087], [...",16.934343,2086.919192,38.740509,6.952181,8277.232422,10729.0,6501.0,...,0.0,104.0,DUMM_giTG62_Glucose_012925,5,445,123,6,,chpS,6.0
4,11,223.0,"[[0, 2203], [0, 2204], [0, 2205], [0, 2206], [...",16.847534,2207.161435,41.267384,7.352654,8576.610352,11085.0,6741.0,...,0.0,110.0,DUMM_giTG62_Glucose_012925,5,445,141,10,,chpS,9.0


In [12]:
def plot_kymograph_cells_id(phase_kymograph, fluor_kymograph, full_region_df, folder, fov_id, peak_id, track_id_col='track_id', output_dir=None):
    """Save a phase-kymograph figure with cells outlined, coloured and labelled by `track_id_col`.

    Parameters
    ----------
    phase_kymograph : np.ndarray
        Phase kymograph image drawn as the background (shown with imshow).
    fluor_kymograph : np.ndarray
        Currently unused; accepted for API compatibility.
    full_region_df : pd.DataFrame
        Cells of this trench; must have the columns `_plot_cell_masks` reads
        ('coords', 'centroid_y', 'centroid_x' and `track_id_col`).
    folder, fov_id, peak_id
        Experiment name, FOV and trench id; used in the title and output filename.
    track_id_col : str
        Column used to colour and label the cells.
    output_dir : str or None
        Directory for the PNG. None (default) writes to the current working directory.

    The filename and title include `track_id_col`. Plotting the same trench with several
    lineage columns therefore produces separate files; previously each call overwrote the
    PNG written by the one before.
    """
    fig, ax = plt.subplots(1, 1, figsize=(40, 10))
    try:
        ax.imshow(phase_kymograph, cmap='gray')
        _plot_cell_masks(ax, full_region_df, phase_kymograph.shape,
                         y_coord_col='centroid_y', x_coord_col='centroid_x',
                         lineage_col=track_id_col)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_title(f'Phase Kymograph - {folder} FOV: {fov_id}, trench: {peak_id} ({track_id_col})')
        ax.set_xlabel("Time frames")
        fig.tight_layout()

        suffix = f'_{track_id_col}' if track_id_col else ''
        out_name = f'{folder}_FOV_{fov_id}_trench_{peak_id}{suffix}_kymograph.png'
        out_path = out_name if output_dir is None else os.path.join(output_dir, out_name)
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure: this is called in a loop over many trenches.
        plt.close(fig)

def _cell_outline_segments(cell_pixel_coords, buffer=1):
    """Trace one cell's outline(s); returns a list of (K, 2) arrays of global (x, y) points.

    `cell_pixel_coords` is an integer (M, 2) array of the cell's (row, col) pixels (M >= 1).
    `buffer` (>= 1) empty pixels are padded around the cell's bounding box.
    """
    min_row, min_col = cell_pixel_coords.min(axis=0)
    max_row, max_col = cell_pixel_coords.max(axis=0)

    # The local mask is only a scratch canvas, so its padding is deliberately NOT clipped to
    # the kymograph bounds. Clipping removed the background border for cells touching row 0
    # or column 0, and find_contours then returned an open (unclosed) outline for them.
    origin_row = min_row - buffer
    origin_col = min_col - buffer
    local_mask = np.zeros(
        (max_row - min_row + 1 + 2 * buffer, max_col - min_col + 1 + 2 * buffer),
        dtype=np.uint8,
    )
    local_mask[cell_pixel_coords[:, 0] - origin_row, cell_pixel_coords[:, 1] - origin_col] = 1

    # level=0.5 traces the 0/1 boundary. fully_connected='high' treats pixels above the level
    # (the cell) as 8-connected and the background as 4-connected.
    contours = find_contours(local_mask, level=0.5, fully_connected='high')

    # find_contours returns (row, col); shift back to global coordinates and swap to
    # (x, y) = (col, row) for plotting.
    return [np.column_stack((c[:, 1] + origin_col, c[:, 0] + origin_row)) for c in contours]


def _plot_cell_masks(ax, full_region_df, kymograph_shape, y_coord_col = 'centroid_y', x_coord_col = 'centroid_x', lineage_col = None):
    """Overlay cell outlines, centroids and optional lineage labels on `ax`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that already shows the kymograph image.
    full_region_df : pd.DataFrame
        One row per cell with a 'coords' column of (row, col) pixel pairs, the centroid
        columns named below, and `lineage_col` when given.
    kymograph_shape : tuple
        Shape of the kymograph. No longer needed, because each cell's scratch mask is padded
        independently of the image bounds; kept so existing callers keep working.
    y_coord_col, x_coord_col : str
        Columns holding each cell's centroid row (y) and column (x).
    lineage_col : str or None
        Column used to colour and label cells; equal values share a colour. Rows whose value
        is NaN get the default colours and no label. If None, all cells use the default
        colours and no labels are drawn.
    """
    default_cell_contour_color = '#AA5486'
    default_centroid_color = '#FC8F54'
    text_color = 'white'  # colour of the lineage label text
    text_offset_x = 5  # shift labels right of the centroid so they don't cover it
    text_offset_y = 0

    # Collect everything first and draw with a few batched calls at the end.
    all_contours_segments = []
    all_contour_colors = []
    centroid_x_coords = []
    centroid_y_coords = []
    centroid_colors = []
    all_text_params = []

    track_colors = {}
    if lineage_col:
        unique_track_ids = full_region_df[lineage_col].dropna().unique()
        if len(unique_track_ids) > 0:
            colors_cmap = plt.get_cmap('tab20', len(unique_track_ids))
            track_colors = {track_id: colors_cmap(i) for i, track_id in enumerate(unique_track_ids)}

    for _, region_props in full_region_df.iterrows():
        cell_pixel_coords = np.asarray(region_props['coords'], dtype=np.intp).reshape(-1, 2)
        if cell_pixel_coords.shape[0] == 0:
            continue  # no pixels to outline (np.min on an empty array would raise)

        global_contours = _cell_outline_segments(cell_pixel_coords)
        if not global_contours:
            continue  # nothing traceable for this cell

        y_coord = region_props[y_coord_col]
        x_coord = region_props[x_coord_col]

        if lineage_col and region_props[lineage_col] in track_colors:
            current_color = track_colors[region_props[lineage_col]]
        else:
            current_color = default_cell_contour_color

        for contour_segment in global_contours:
            all_contours_segments.append(contour_segment)
            all_contour_colors.append(current_color)

        centroid_x_coords.append(x_coord)
        centroid_y_coords.append(y_coord)
        centroid_colors.append(current_color if lineage_col else default_centroid_color)

        if lineage_col and pd.notna(region_props[lineage_col]):
            track_id = region_props[lineage_col]
            # Show integral floats (e.g. 19.0) as "19".
            if isinstance(track_id, float) and track_id.is_integer():
                track_id_display = int(track_id)
            else:
                track_id_display = track_id

            all_text_params.append({
                'x': x_coord + text_offset_x,
                'y': y_coord + text_offset_y,
                's': str(track_id_display),
                'color': text_color,
                'fontsize': 8,
                'ha': 'left',
                'va': 'center',
                'bbox': dict(facecolor=current_color, edgecolor='none', alpha=0.6, pad=1.0),
            })

    if all_contours_segments:
        line_collection = LineCollection(all_contours_segments, colors=all_contour_colors, linewidths=0.5)
        ax.add_collection(line_collection)

    if centroid_x_coords:
        ax.scatter(centroid_x_coords, centroid_y_coords, color=centroid_colors, s=5, zorder=2)

    for params in all_text_params:
        ax.text(params['x'], params['y'], params['s'], color=params['color'],
                fontsize=params['fontsize'], ha=params['ha'], va=params['va'],
                bbox=params['bbox'])

In [13]:
# One record per (experiment, FOV, trench) kymograph, in order of first appearance.
unique_fovs = (
    df_for_kymograph_plot[['experiment_name', 'FOV', 'trench_id']]
    .drop_duplicates()
    .to_records(index=False)
)

In [None]:
# Kymograph TIFFs for these experiments live under
#   <experiment>/hyperstacked/drift_corrected/rotated/mm_channels/subtracted/        (phase)
#   <experiment>/hyperstacked/drift_corrected/rotated/mm_channels/subtracted/fluor/  (fluor)
# Other experiments used different layouts, e.g.
#   <experiment>/hyperstacked/drift_corrected/fast4deg_drift_corrected/rotated/mm_channels/subtracted/
#   <experiment>/kymographs/{phase,fluor}/
# Adjust `kymograph_dir` below when switching datasets.
DUMM_ANALYSIS_DIR = '/Users/noravivancogonzalez/Documents/DuMM_image_analysis'
# Each trench is plotted once per column: the original track ids and the GNN prediction.
LINEAGE_COLS_TO_PLOT = ['track_id', 'predicted_lineage']

for cell in unique_fovs:
    exp, fov, trench = cell
    kymograph_dir = f'{DUMM_ANALYSIS_DIR}/{exp}/hyperstacked/drift_corrected/rotated/mm_channels/subtracted'
    path_to_phase_kymograph = f'{kymograph_dir}/{fov}_{trench}.tif'
    path_to_fluor_kymograph = f'{kymograph_dir}/fluor/{fov}_{trench}.tif'

    if not (os.path.exists(path_to_phase_kymograph) and os.path.exists(path_to_fluor_kymograph)):
        # Report instead of silently skipping, so missing kymographs are noticed.
        print(f'Skipping {cell}: kymograph TIFF(s) not found')
        continue

    print(cell)
    phase_kymograph = tifffile.imread(path_to_phase_kymograph)
    fluor_kymograph = tifffile.imread(path_to_fluor_kymograph)
    df_view = df_for_kymograph_plot[
        (df_for_kymograph_plot['experiment_name'] == exp)
        & (df_for_kymograph_plot['FOV'] == fov)
        & (df_for_kymograph_plot['trench_id'] == trench)
    ].copy()

    for lineage_col in LINEAGE_COLS_TO_PLOT:
        plot_kymograph_cells_id(phase_kymograph, fluor_kymograph,
                                df_view,
                                exp, fov, trench,
                                track_id_col=lineage_col)