In [22]:
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing,GCNConv, global_mean_pool
from torch_geometric.utils import degree # For potentially normalizing attention
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops, negative_sampling
from sklearn.model_selection import train_test_split
import tifffile
import matplotlib.pyplot as plt
from skimage.measure import find_contours
from skimage.draw import polygon
from matplotlib.collections import LineCollection
import pandas as pd
import numpy as np
from scipy.spatial.distance import euclidean

In [2]:
folder = 'DUMM_giTG69_Glucose_013025'
all_cells_filename = f'/Users/noravivancogonzalez/Documents/DuMM_image_analysis/all_cell_data_{folder}.pkl'
all_cells_pd = pd.read_pickle(all_cells_filename)
FOV = '007'
trench_id = '295'
df = all_cells_pd[(all_cells_pd['FOV'] == FOV) & (all_cells_pd['trench_id'] == trench_id)].copy()
df['track_id'] = df['track_id'].astype(int).astype(str)

In [3]:
#for fov 007 and trench id 295
# manually correcting lineages id from track id
ground_truth_lineage_id_dict = {'7': 'A',
                                '25':'A.1',
                                '46':'A.1',
                                '67':'A.1',
                                '73':'A.1.1',
                                '79':'A.1.1',
                                '97':'A.1.1',
                                '102':'A.1.1',
                                '108':'A.1.1',
                                '115':'A.1.1',
                                '119':'A.1.1',
                                '124':'A.1.1',
                                '130':'A.1.1',
                                '132':'A.1.1',
                                '135':'A.1.1',
                                '140':'A.1.1',
                                '142':'A.1.1',
                                '143':'A.1.1.1',
                                '151':'A.1.1.1',
                                '154':'A.1.1.1',
                                '159':'A.1.1.1',
                                '163':'A.1.1.1',
                                '167':'A.1.1.1',
                                '172':'A.1.1.1',
                                '178':'A.1.1.1',
                                '183':'A.1.1.1',
                                '144':'A.1.1.2',
                                '152':'A.1.1.2',
                                '155':'A.1.1.2',
                                '160':'A.1.1.2',
                                '164':'A.1.1.2',
                                '168':'A.1.1.2',
                                '173':'A.1.1.2',
                                '179':'A.1.1.2',
                                '184':'A.1.1.2',
                                '187':'A.1.1.2',
                                '191':'A.1.1.2',
                                '196':'A.1.1.2',
                                '202':'A.1.1.2',
                                '74':'A.1.2',
                                '80':'A.1.2',
                                '93':'A.1.2',
                                '109':'A.1.2',
                                '120':'A.1.2',
                                '125':'A.1.2',
                                '131':'A.1.2',
                                '133':'A.1.2',
                                '136':'A.1.2',
                                '141':'A.1.2',
                                '145':'A.1.2.1',
                                '150':'A.1.2.1',
                                '156':'A.1.2.1',
                                '161':'A.1.2.1',
                                '169':'A.1.2.1',
                                '174':'A.1.2.1',
                                '180':'A.1.2.1',
                                '188':'A.1.2.1',
                                '192':'A.1.2.1',
                                '197':'A.1.2.1',
                                '203':'A.1.2.1',
                                '208':'A.1.2.1',
                                '213':'A.1.2.1',
                                '146':'A.1.2.2',
                                '162':'A.1.2.2',
                                '165':'A.1.2.2',
                                '170':'A.1.2.2',
                                '175':'A.1.2.2',
                                '181':'A.1.2.2',
                                '193':'A.1.2.2',
                                '198':'A.1.2.2',
                                '204':'A.1.2.2',
                                '209':'A.1.2.2',
                                '216':'A.1.2.2',
                                '222':'A.1.2.2',
                                '74':'A.1.2',
                                '19': 'A.2',
                                '81': 'A.2.1',
                                '88':'A.2.1',
                                '94':'A.2.1',
                                '137':'A.2.1',
                                '147':'A.2.1.1',
                                '166':'A.2.1.1',
                                '176':'A.2.1.1',
                                '182':'A.2.1.1',
                                '189':'A.2.1.1',
                                '194':'A.2.1.1',
                                '199':'A.2.1.1',
                                '205':'A.2.1.1',
                                '210':'A.2.1.1',
                                '217':'A.2.1.1',
                                '223':'A.2.1.1',
                                '229':'A.2.1.1',
                                '234':'A.2.1.1',
                                '148':'A.2.1.2',
                                '157':'A.2.1.2',
                                '177':'A.2.1.2',
                                '195':'A.2.1.2',
                                '206':'A.2.1.2',
                                '211':'A.2.1.2',
                                '214':'A.2.1.2',
                                '218':'A.2.1.2',
                                '224':'A.2.1.2.1',
                                '230':'A.2.1.2.1',
                                '235':'A.2.1.2.1',
                                '240':'A.2.1.2.1',
                                '245':'A.2.1.2.1',
                                '249':'A.2.1.2.1',
                                '225':'A.2.1.2.2',
                                '82': 'A.2.2',
                                '138':'A.2.2.1',
                                '158':'A.2.2.1',
                                '171':'A.2.2.1',
                                '190':'A.2.2.1',
                                '207':'A.2.2.1',
                                '212':'A.2.2.1',
                                '220':'A.2.2.1',
                                '227':'A.2.2.1',
                                '232':'A.2.2.1',
                                '238':'A.2.2.1',
                                '243':'A.2.2.1',
                                '247':'A.2.2.1',
                                '255':'A.2.2.1',
                                '258':'A.2.2.1',
                                '262':'A.2.2.1',
                                '264':'A.2.2.1',
                                '267':'A.2.2.1.1',
                                '268':'A.2.2.1.2',
                                '139':'A.2.2.2',
                                '200': 'A.2.2.2.1',
                                '215': 'A.2.2.2.1',
                                '221': 'A.2.2.2.1',
                                '228': 'A.2.2.2.1',
                                '233': 'A.2.2.2.1',
                                '239': 'A.2.2.2.1',
                                '244': 'A.2.2.2.1',
                                '248': 'A.2.2.2.1',
                                '252': 'A.2.2.2.1',
                                '256': 'A.2.2.2.1',
                                '259': 'A.2.2.2.1',
                                '263': 'A.2.2.2.1',
                                '265': 'A.2.2.2.1',
                                '266': 'A.2.2.2.1',
                                '269': 'A.2.2.2.1',
                                '272': 'A.2.2.2.1',
                                '274': 'A.2.2.2.1',
                                '201': 'A.2.2.2.2',
                                '273': 'A.2.2.2.2',
                                '277': 'A.2.2.2.2',
                                '278': 'A.2.2.2.2.1',
                                '280': 'A.2.2.2.2.1',
                                '281': 'A.2.2.2.2.1',
                                '282': 'A.2.2.2.2.1',
                                '283': 'A.2.2.2.2.1',
                                '284': 'A.2.2.2.2.1',
                                '285': 'A.2.2.2.2.1',
                                '286': 'A.2.2.2.2.1',
                                '287': 'A.2.2.2.2.1.1',
                                '288': 'A.2.2.2.2.1.2',
                                '279': 'A.2.2.2.2.2',
                                '289': 'A.2.2.2.2.1',
                                '291': 'A.2.2.2.2.1',
                                '292': 'A.2.2.2.2.1',
                                '293': 'A.2.2.2.2.1',
                                '294': 'A.2.2.2.2.1',
                                '295': 'A.2.2.2.2.1',
                                '296': 'A.2.2.2.2.1',
                                '297': 'A.2.2.2.2.1',
                                '298': 'A.2.2.2.2.1',
                                '290': 'A.2.2.2.2.2.2',
                                '299': 'A.2.2.2.2.2.2',
                                '300': 'A.2.2.2.2.2.2',
                                '301': 'A.2.2.2.2.2.2',
                                '302': 'A.2.2.2.2.2.2.1',
                                '304': 'A.2.2.2.2.2.2.1',
                                '305': 'A.2.2.2.2.2.2.1',
                                '306': 'A.2.2.2.2.2.2.1',
                                '307': 'A.2.2.2.2.2.2.1',
                                '308': 'A.2.2.2.2.2.2.1',
                                '309': 'A.2.2.2.2.2.2.1',
                                '303': 'A.2.2.2.2.2.2.2',
                                '310': 'A.2.2.2.2.2.2.2',
                                '311': 'A.2.2.2.2.2.2.2',
                                '312': 'A.2.2.2.2.2.2.2',
                                '313': 'A.2.2.2.2.2.2.2.1',
                                '315': 'A.2.2.2.2.2.2.2.1',
                                '316': 'A.2.2.2.2.2.2.2.1',
                                '317': 'A.2.2.2.2.2.2.2.1',
                                '318': 'A.2.2.2.2.2.2.2.1',
                                '319': 'A.2.2.2.2.2.2.2.1',
                                '320': 'A.2.2.2.2.2.2.2.1',
                                '314': 'A.2.2.2.2.2.2.2.2',
                                '321': 'A.2.2.2.2.2.2.2.2',
                                '322': 'A.2.2.2.2.2.2.2.2',
                                '323': 'A.2.2.2.2.2.2.2.2',
                                '324': 'A.2.2.2.2.2.2.2.2',
                                '325': 'A.2.2.2.2.2.2.2.2',
                                '326': 'A.2.2.2.2.2.2.2.2',
                                '327': 'A.2.2.2.2.2.2.2.2.1',
                                '329': 'A.2.2.2.2.2.2.2.2.1',
                                '330': 'A.2.2.2.2.2.2.2.2.1',
                                '328': 'A.2.2.2.2.2.2.2.2.2',
                                '331': 'A.2.2.2.2.2.2.2.2.2',
                                '332': 'A.2.2.2.2.2.2.2.2.2.1',
                                '333': 'A.2.2.2.2.2.2.2.2.2.2'}

In [4]:
df['ground_truth_lineage'] = None
df['ground_truth_lineage']= df['track_id'].map(ground_truth_lineage_id_dict)
df.rename(columns = {'centroid-0': 'centroid_y','centroid-1': 'centroid_x'}, inplace = True)
df_cells = df[df['ground_truth_lineage'].notna()]

In [6]:
# Convert lineage IDs to unique integers as labels
all_unique_lineages = sorted(df_cells['ground_truth_lineage'].unique())
lineage_to_int_mapping = {lineage: i for i, lineage in enumerate(all_unique_lineages)}
num_lineage_classes = len(all_unique_lineages)

In [7]:
df_cells['numeric_lineage'] = df_cells['ground_truth_lineage'].map(lineage_to_int_mapping)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cells['numeric_lineage'] = df_cells['ground_truth_lineage'].map(lineage_to_int_mapping)


In [9]:
df_cells['node_id'] = df_cells.index # Assign unique global node ID

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cells['node_id'] = df_cells.index # Assign unique global node ID


# try architecture like https://link.springer.com/chapter/10.1007/978-3-031-19803-8_36

In [44]:
# The function expects a sub-DataFrame already filtered for a specific lineage branch
def create_lineage_graph(df_lineage, device='cpu'):
    original_global_node_ids = df_lineage['node_id'].values
    global_id_to_local_idx = {global_id: i for i, global_id in enumerate(original_global_node_ids)}

    x = torch.tensor(df_lineage[node_feature_cols].to_numpy(dtype=np.float32), dtype=torch.float).to(device)
    
    y = torch.tensor(df_lineage['numeric_lineage'].values, dtype=torch.long).to(device)
    # Ensure pos is 2D if you have both x and y, or handle 1D appropriately in negative sampling
    # If your 'pos' is just centroid_y, then it's [num_nodes]. Need to reshape to [num_nodes, 1]
    # for consistent tensor operations.
    pos = torch.tensor(df_lineage['centroid_y'].values, dtype=torch.float).to(device)
    if pos.dim() == 1: # Reshape if it's just a 1D tensor of y-coords
        pos = pos.unsqueeze(1) # Makes it [num_nodes, 1]

    node_time_frames = torch.tensor(df_lineage['time_frame'].values, dtype=torch.long).to(device)

    num_nodes = len(df_lineage)
    if num_nodes == 0:
        return None

    source_nodes_local_idx = []
    target_nodes_local_idx = []

    sorted_time_frames = sorted(df_lineage['time_frame'].unique())

    # This loop builds the lists of source_nodes_local_idx and target_nodes_local_idx
    for i in range(len(sorted_time_frames) - 1):
        current_t = sorted_time_frames[i]
        next_t = sorted_time_frames[i+1]

        df_current_t = df_lineage[df_lineage['time_frame'] == current_t]
        df_next_t = df_lineage[df_lineage['time_frame'] == next_t]

        current_lineage_to_node = df_current_t.set_index('ground_truth_lineage')['node_id'].to_dict()
        next_lineage_to_node = df_next_t.set_index('ground_truth_lineage')['node_id'].to_dict()

        for idx, row in df_current_t.iterrows():
            current_global_node_id = row['node_id']
            current_ground_truth_lineage = row['ground_truth_lineage']

            if current_ground_truth_lineage in next_lineage_to_node:
                next_global_node_id = next_lineage_to_node[current_ground_truth_lineage]
                source_nodes_local_idx.append(global_id_to_local_idx[current_global_node_id])
                target_nodes_local_idx.append(global_id_to_local_idx[next_global_node_id])

            daughter1_lineage = f"{current_ground_truth_lineage}.1"
            daughter2_lineage = f"{current_ground_truth_lineage}.2"

            if daughter1_lineage in next_lineage_to_node:
                next_global_node_id = next_lineage_to_node[daughter1_lineage]
                source_nodes_local_idx.append(global_id_to_local_idx[current_global_node_id])
                target_nodes_local_idx.append(global_id_to_local_idx[next_global_node_id])
            if daughter2_lineage in next_lineage_to_node:
                next_global_node_id = next_lineage_to_node[daughter2_lineage]
                source_nodes_local_idx.append(global_id_to_local_idx[current_global_node_id])
                target_nodes_local_idx.append(global_id_to_local_idx[next_global_node_id])

    # --- Determine edge_index based on collected source/target lists ---
    if not source_nodes_local_idx: # Case: No edges found at all for this lineage branch
        edge_index = torch.empty((2, 0), dtype=torch.long).to(device)
        # For this case, initial_edge_attr must also be an empty tensor
        initial_edge_attr = torch.empty((0, len(node_feature_cols) + 1), dtype=torch.float).to(device)
    else: # Case: Some potential edges were found
        unique_edges = list(set(zip(source_nodes_local_idx, target_nodes_local_idx)))
        if not unique_edges: # This should theoretically be covered by the first 'if', but acts as a safeguard
            edge_index = torch.empty((2, 0), dtype=torch.long).to(device)
            initial_edge_attr = torch.empty((0, len(node_feature_cols) + 1), dtype=torch.float).to(device)
        else: # Case: Non-empty, unique edges exist
            source_nodes_unique, target_nodes_unique = zip(*unique_edges)
            edge_index = torch.tensor([list(source_nodes_unique), list(target_nodes_unique)], dtype=torch.long).to(device)

            # --- ONLY in this 'else' block, calculate initial_edge_attr for the actual edges ---
            # Get node features for source and target nodes of the created edges
            # (Use df_lineage.iloc with the numpy conversion of edge_index for correct indexing)
            true_src_node_dfs = df_lineage.iloc[edge_index[0].cpu().numpy()]
            true_tgt_node_dfs = df_lineage.iloc[edge_index[1].cpu().numpy()]

            initial_edge_features_list = []
            for i in range(edge_index.size(1)): # Iterate over the actual unique edges
                # Use original node features for initial D-S block

                v_i_raw = torch.tensor(true_src_node_dfs.iloc[i][node_feature_cols].to_numpy(dtype=np.float32), dtype=torch.float)
                v_j_raw = torch.tensor(true_tgt_node_dfs.iloc[i][node_feature_cols].to_numpy(dtype=np.float32), dtype=torch.float)
                initial_edge_features_list.append(DS_block(v_i_raw, v_j_raw))

            initial_edge_attr = torch.stack(initial_edge_features_list, dim=0).to(device)
            # No 'else' needed here, as this block only runs if initial_edge_features_list is guaranteed to be non-empty
            # because edge_index.size(1) > 0.

    data = Data(x=x,
                edge_index=edge_index,
                y=y,
                pos=pos,
                num_nodes=num_nodes,
                time_frame=node_time_frames,
                edge_attr=initial_edge_attr, # This will now always be a correctly shaped tensor
                root_lineage_branch=df_lineage['ground_truth_lineage'].iloc[0],
                start_time_frame=df_lineage['time_frame'].min(),
                experiment_name=df_lineage['experiment_name'].iloc[0],
                fov=df_lineage['FOV'].iloc[0],
                trench_id=df_lineage['trench_id'].iloc[0]
               )
    return data

In [13]:
def generate_local_temporal_negative_samples(data: Data, num_neg_samples_per_pos_edge: float, radius_threshold: float, device='cpu'):
    """
    Generates negative samples by considering only cells in consecutive time frames
    and within a certain spatial radius of potential source nodes, excluding true positives.

    Args:
        data (torch_geometric.data.Data): A single graph batch containing x, edge_index, pos, time_frame.
        num_neg_samples_per_pos_edge (float): Ratio of negative samples to positive samples.
                                                e.g., 1.0 for 1:1, 2.0 for 2:1.
        radius_threshold (float): Maximum spatial distance for a potential negative connection.
        device (str): Device to put tensors on.

    Returns:
        torch.Tensor: edge_index of sampled negative connections, shape [2, num_neg_samples].
    """
    if data.edge_index.numel() == 0: # No positive edges, no negative samples possible this way
        return torch.empty((2, 0), dtype=torch.long, device=device)

    # Convert tensors to CPU for easier numpy/list processing if needed, then back to device
    pos_coords = data.pos.cpu().numpy() # Assuming pos is [num_nodes, 2] (y,x) or [num_nodes, 1] (y)
    time_frames = data.time_frame.cpu().numpy()
    num_nodes = data.num_nodes
    existing_edges = set(tuple(e) for e in data.edge_index.cpu().T.tolist()) # Convert to set for fast lookup

    potential_neg_samples = []

    # Iterate through all possible source nodes
    for i in range(num_nodes):
        current_node_time = time_frames[i]
        current_node_pos = pos_coords[i]

        # Iterate through all possible target nodes (j)
        for j in range(num_nodes):
            # 1. Temporal Constraint: Only consider next time frame
            if time_frames[j] != current_node_time + 1:
                continue

            # 2. Local Constraint: Check spatial proximity (Euclidean distance)
            target_node_pos = pos_coords[j]
            # Adjust distance calculation based on your 'pos' dimension
            if pos_coords.ndim == 1: # If 'pos' is just centroid_y (1D)
                distance = np.abs(current_node_pos - target_node_pos)
            else: # If 'pos' is (y, x) or (x, y) etc. (2D or more)
                distance = np.linalg.norm(current_node_pos - target_node_pos)

            if distance > radius_threshold:
                continue

            # 3. Exclude existing positive edges
            if (i, j) not in existing_edges:
                potential_neg_samples.append((i, j))

    # Convert to tensor
    if not potential_neg_samples:
        return torch.empty((2, 0), dtype=torch.long, device=device), torch.empty((0, data.x.size(1) + 1), dtype=torch.float, device=device)

    potential_neg_samples_tensor = torch.tensor(potential_neg_samples, dtype=torch.long).T.to(device)

    # Sample a subset
    num_positive_edges = data.edge_index.size(1)
    desired_neg_samples = int(num_positive_edges * num_neg_samples_per_pos_edge)

    if desired_neg_samples >= potential_neg_samples_tensor.size(1):
        sampled_neg_edge_index = potential_neg_samples_tensor
    else:
        indices = torch.randperm(potential_neg_samples_tensor.size(1), device=device)[:desired_neg_samples]
        sampled_neg_edge_index = potential_neg_samples_tensor[:, indices]

    # Compute initial edge_attr for the sampled negative edges
    neg_src_nodes_indices = sampled_neg_edge_index[0]
    neg_tgt_nodes_indices = sampled_neg_edge_index[1]

    initial_neg_edge_attr_list = []
    for i in range(sampled_neg_edge_index.size(1)):
        v_i_raw = data.x[neg_src_nodes_indices[i]] # Use original node features (data.x)
        v_j_raw = data.x[neg_tgt_nodes_indices[i]]
        initial_neg_edge_attr_list.append(DS_block(v_i_raw, v_j_raw))

    if initial_neg_edge_attr_list:
        initial_neg_edge_attr = torch.stack(initial_neg_edge_attr_list, dim=0).to(device)
    else:
        initial_neg_edge_attr = torch.empty((0, data.x.size(1) + 1), dtype=torch.float).to(device)

    return sampled_neg_edge_index, initial_neg_edge_attr

In [81]:
# --- Helper MLP for f_PDN_edge (Attention Weights) ---
class PDNEdgeMLP(nn.Module):
    def __init__(self, edge_feature_dim, out_dim=1):
        super().__init__()
        # Simplified MLP for attention weights (scalar output)
        self.mlp = nn.Sequential(
            nn.Linear(edge_feature_dim, 32), # Example hidden dim
            nn.ReLU(),
            nn.Linear(32, out_dim)
        )
    def forward(self, z): # z is edge feature
        return self.mlp(z)

# --- Helper MLP for f_PDN_node (Node Feature Transformation) ---
class PDNNodeMLP(nn.Module):
    def __init__(self, node_feature_dim, out_dim):
        super().__init__()
        # MLP for transforming node features before aggregation
        self.mlp = nn.Sequential(
            nn.Linear(node_feature_dim, out_dim), # Typically out_dim = node_feature_dim
            nn.ReLU()
            # No final ReLU if you want negative values for weighted sum, or add Batch Norm
        )
    def forward(self, x): # x is node feature
        return self.mlp(x)

# --- D-S Block (Distance & Similarity) ---
def DS_block(v_i, v_j):
    """
    Calculates Distance & Similarity vector for two node feature vectors.
    Equivalent to Eq. 3 in the paper.
    Args:
        v_i (torch.Tensor): Feature vector of node i, shape [d_v].
        v_j (torch.Tensor): Feature vector of node j, shape [d_v].
    Returns:
        torch.Tensor: Concatenated vector of absolute differences and cosine similarity, shape [d_v + 1].
    """
    abs_diff = torch.abs(v_i - v_j)
    cosine_similarity = F.cosine_similarity(v_i.unsqueeze(0), v_j.unsqueeze(0)).squeeze(0)
    return torch.cat([abs_diff, cosine_similarity.unsqueeze(0)], dim=-1) # Unsqueeze for scalar cos_sim

# --- The EP-MPNN Block ---
class EP_MPNN_Block(MessagePassing):
    def __init__(self, node_channels, edge_channels):
        super().__init__(aggr='add', flow='source_to_target') # Aggregation for node update. source_to_target for N(i) being t-1 nodes.
        self.node_channels = node_channels
        self.edge_channels = edge_channels

        # Node Feature Update components (PDN-Conv)
        # f_PDN_edge: Maps edge features to scalar attention weights (omega)
        self.f_pdn_edge = PDNEdgeMLP(edge_channels, out_dim=1)
        # f_PDN_node: Transforms node features (tilde_x)
        self.f_pdn_node = PDNNodeMLP(node_channels, node_channels) # Output dim same as input for residuals

        # Edge Feature Update components (Edge Encoder)
        # f_EE_edge: MLP to update edge features.
        # Input: current edge_features (edge_channels)
        #        + updated node_features from source (node_channels)
        #        + updated node_features from target (node_channels)
        #        + D-S block output (node_channels + 1)
        self.f_ee_edge = nn.Sequential(
            nn.Linear(edge_channels + 2 * node_channels + (node_channels + 1), 128), # Example hidden size
            nn.ReLU(),
            nn.Linear(128, edge_channels) # Output dim same as edge_channels
        )

        # BatchNorm (optional but often helpful for stability)
        self.bn_node = nn.BatchNorm1d(node_channels)
        self.bn_edge = nn.BatchNorm1d(edge_channels)

    def forward(self, x, edge_index, edge_attr):
        # x: node features X^(l-1)
        # edge_index: graph connectivity
        # edge_attr: edge features Z^(l-1)

        # 1. Edge Feature Update (first for this block, as per paper's description "In the l-th block vi = x(l)i and vj = x(l)j")
        # However, the paper implies x(l) is used. Let's assume for simplicity first block
        # uses x(l-1) and subsequent blocks use x(l).
        # To align with: "In the l-th block vi = x(l)i and vj = x(l)j." and "the features of an edge ej,i are updated based on the features of νi and νj"
        # This means edge update uses nodes *after* they are potentially updated by previous block.
        # For l=0 (initial), x(0) are raw features. For l>0, x(l) comes from PDN-Conv.
        # For simplicity, let's make it work sequentially: update nodes, THEN update edges using new nodes.
        # Or, as paper implies "alternately updated", meaning within the block loop:
        # Step A: Compute updated nodes x^(l) from x^(l-1) and z^(l-1)
        # Step B: Compute updated edges z^(l) from x^(l) and z^(l-1)
        # Let's follow this:

        # Cache inputs for edge update after node update
        x_prev = x
        edge_attr_prev = edge_attr

        # 2. Node Feature Update (PDN-Conv: Eq. 2)
        # Message passing step
        
        x_updated = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=(x.size(0), x.size(0)))

        # Add residual connection and apply BatchNorm/ReLU
        x = self.bn_node(x_prev + x_updated) # Residual (assuming input and output dims are same)
        x = F.relu(x) # x is now X^(l)

        # 3. Edge Feature Update (Edge Encoder: based on x^(l) and z^(l-1))
        # Get source and target node embeddings for edges
        row, col = edge_index
        src_node_features = x[row] # x^(l) for source nodes
        tgt_node_features = x[col] # x^(l) for target nodes

        # Compute D-S block output for each edge
        ds_outputs = []
        for i in range(edge_attr.size(0)): # Iterate per edge
            ds_outputs.append(DS_block(src_node_features[i], tgt_node_features[i]))
        ds_outputs_tensor = torch.stack(ds_outputs, dim=0) # Shape: [num_edges, node_channels + 1]

        # Concatenate inputs for f_EE_edge
        # current edge_features (Z^(l-1))
        # updated node_features from source (X^(l))
        # updated node_features from target (X^(l))
        # D-S block output (from X^(l), X^(l))
        edge_input_for_mlp = torch.cat([
            edge_attr_prev, # Z^(l-1)
            src_node_features, # X^(l)
            tgt_node_features, # X^(l)
            ds_outputs_tensor # D-S block on X^(l)
        ], dim=-1)

        # Pass through edge encoder MLP
        edge_attr = self.f_ee_edge(edge_input_for_mlp) # Z^(l)
        edge_attr = self.bn_edge(edge_attr) # BatchNorm
        edge_attr = F.relu(edge_attr) # ReLU

        return x, edge_attr # Return updated nodes (X^(l)) and updated edges (Z^(l))

    def message(self, x_j, edge_attr): # x_j is neighbor features, edge_attr_i is edge features to neighbor
        # x_j: x^(l-1)_j (features of neighbor j)
        # edge_attr: z^(l-1)_ji (features of edge (j,i))
        # Compute omega_ji = f_PDN_edge(z_ji) (attention weight for edge j,i)
        omega_ji = self.f_pdn_edge(edge_attr)
        # Compute tilde_x_j = f_PDN_node(x_j) (mapped feature vector of node j)
        tilde_x_j = self.f_pdn_node(x_j)

        # The message is omega_ji * tilde_x_j
        return omega_ji * tilde_x_j

    def aggregate(self, inputs, index, dim_size=None):
        # inputs: [num_messages, hidden_channels] (omega_ji * tilde_x_j for each edge)
        # index: target node index for each message
        # dim_size: total number of nodes
        # Summation aggregation (as per Eq. 2)
        out = super().aggregate(inputs, index, dim_size=dim_size)
        return out

    def update(self, aggr_out):
        # This is where the output of aggregation (sum_j omega_ji * tilde_x_j)
        # is combined with the current node feature.
        # But per Eq. 2, the residual is handled in the forward pass.
        # So we just return the aggregated messages here.
        return aggr_out # This will be the x_updated in the forward pass

In [82]:
class LineageLinkPredictionGNN(nn.Module):
    def __init__(self, in_channels, initial_edge_channels, hidden_channels, num_blocks=2):
        super().__init__()
        self.num_blocks = num_blocks
        self.hidden_channels = hidden_channels

        # Initial Linear layer to project input features to hidden_channels
        self.initial_node_proj = nn.Linear(in_channels, hidden_channels)

        # Initial Edge Feature Projector (optional, if initial_edge_channels is different from hidden_channels)
        # Or you can define specific initial edge features.
        self.initial_edge_proj = nn.Linear(initial_edge_channels, hidden_channels)

        # Stack L EP-MPNN blocks
        self.ep_mpnn_blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.ep_mpnn_blocks.append(EP_MPNN_Block(hidden_channels, hidden_channels))

        # Decoder for link prediction (takes concatenated node embeddings)
        self.decoder = nn.Sequential(
            nn.Linear(2 * hidden_channels, 64),
            nn.ReLU(),
            nn.Linear(64, 1) # Output a single logit for binary classification
        )

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr # Data now also has edge_attr

        # Initial projection of node features
        x = F.relu(self.initial_node_proj(x))

        # Initial projection of edge features
        edge_attr = F.relu(self.initial_edge_proj(edge_attr)) # Ensure initial_edge_channels maps to hidden_channels

        # Pass through L EP-MPNN blocks
        for block in self.ep_mpnn_blocks:
            x, edge_attr = block(x, edge_index, edge_attr) # Both nodes and edges get updated

        # x are the final node embeddings after L blocks
        return x # Return node embeddings for the decoder

    def decode(self, z, pos_edge_index, neg_edge_index=None):
        # This decode method remains largely the same as before,
        # as it operates on the final node embeddings 'z'.
        edge_indices = torch.cat([pos_edge_index, neg_edge_index], dim=-1) if neg_edge_index is not None else pos_edge_index

        src_embed = z[edge_indices[0]]
        tgt_embed = z[edge_indices[1]]
        edge_features = torch.cat([src_embed, tgt_embed], dim=-1) # Concatenate

        logits = self.decoder(edge_features).squeeze(-1)
        return logits

In [83]:
def train_link_prediction(model, train_loader, optimizer, criterion, device, neg_sample_ratio=1.0, radius_threshold=None):
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        neg_edge_index, neg_edge_attr_initial = generate_local_temporal_negative_samples(
            data,
            num_neg_samples_per_pos_edge=neg_sample_ratio,
            radius_threshold=radius_threshold,
            device=device
        )
        if neg_edge_index.numel() == 0:
            print("Warning: No negative samples generated for a batch. Skipping.")
            continue # Skip this batch if no valid negative samples

        z = model(data) # Forward pass returns final node embeddings

        pos_logits = model.decode(z, data.edge_index)
        neg_logits = model.decode(z, neg_edge_index)

        pos_labels = torch.ones(pos_logits.size(0), device=device)
        neg_labels = torch.zeros(neg_logits.size(0), device=device)

        logits = torch.cat([pos_logits, neg_logits])
        labels = torch.cat([pos_labels, neg_labels])

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    return avg_loss

def evaluate_link_prediction(model, loader, criterion, device, neg_sample_ratio=1.0, radius_threshold=None):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            # Generate negative edges for evaluation
            neg_edge_index, neg_edge_attr_initial = generate_local_temporal_negative_samples(
                data,
                num_neg_samples_per_pos_edge=neg_sample_ratio,
                radius_threshold=radius_threshold,
                device=device
            )

            if neg_edge_index.numel() == 0:
                print("Warning: No negative samples generated for a batch during evaluation. Skipping.")
                continue

            # Combine positive and negative edges for the GNN's forward pass
            combined_edge_index = torch.cat([data.edge_index, neg_edge_index], dim=1)
            combined_edge_attr_initial = torch.cat([data.edge_attr, neg_edge_attr_initial], dim=0)

            temp_data_for_forward = data.clone()
            temp_data_for_forward.edge_index = combined_edge_index
            temp_data_for_forward.edge_attr = combined_edge_attr_initial

            z = model(temp_data_for_forward) # Pass the combined data for message passing

            pos_logits = model.decode(z, data.edge_index)
            neg_logits = model.decode(z, neg_edge_index)

            pos_labels = torch.ones(pos_logits.size(0), device=device)
            neg_labels = torch.zeros(neg_logits.size(0), device=device)

            logits = torch.cat([pos_logits, neg_logits])
            labels = torch.cat([pos_labels, neg_labels])

            loss = criterion(logits, labels)
            total_loss += loss.item()

            preds = (torch.sigmoid(logits) > 0.5).long() # Convert logits to binary predictions
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    avg_loss = total_loss / len(loader)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    accuracy = (all_preds == all_labels).float().mean().item()
    # You might also want to calculate precision, recall, F1-score for link prediction
    return avg_loss, accuracy

In [84]:
# Lineage dataset
class LineageDataset(Dataset):
    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

In [85]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [86]:
# Define node features
node_feature_cols = ['area', 'centroid_y', 
       'axis_major_length', 'axis_minor_length', 'intensity_mean_phase',
       'intensity_max_phase', 'intensity_min_phase', 'intensity_mean_fluor',
       'intensity_max_fluor', 'intensity_min_fluor']

for col in node_feature_cols:
    df_cells.loc[:,col] = df_cells[col].astype(np.float32);   

In [87]:
def identify_sub_lineage_roots(df):
    # Ensure relevant columns are present
    required_cols = ['experiment_name', 'FOV', 'trench_id', 'ground_truth_lineage', 'time_frame', 'node_id']
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain all required columns: {required_cols}")
    df_temp_sorted = df.sort_values(by=['time_frame', 'node_id'])
    first_appearances = df_temp_sorted.drop_duplicates(
        subset=['experiment_name', 'FOV', 'trench_id', 'ground_truth_lineage'],
        keep='first'
    )

    # Extract the necessary information for each root
    # Convert to list of tuples as in the original function's output format
    sub_lineage_roots = list(first_appearances[[
        'experiment_name',
        'FOV',
        'trench_id',
        'ground_truth_lineage',
        'time_frame'
    ]].itertuples(index=False, name=None))

    return sub_lineage_roots

In [88]:
# create graphs
sub_lineage_roots_tuples = identify_sub_lineage_roots(df_cells)
branch_graphs_list = []

for exp, fov, trench, root_lineage_str, start_t in sub_lineage_roots_tuples:
    # Filter the DataFrame to include only the cells belonging to this specific sub-lineage branch
    # and starting from its first appearance time
    df_branch = df_cells[
        (df_cells['experiment_name'] == exp) &
        (df_cells['FOV'] == fov) &
        (df_cells['trench_id'] == trench) &
        (df_cells['time_frame'] >= start_t) &
        # This regex ensures we only get descendants of this specific root_lineage_str
        (df_cells['ground_truth_lineage'].apply(lambda x: x == root_lineage_str or x.startswith(f"{root_lineage_str}.")))
    ].copy() # Use .copy() to avoid SettingWithCopyWarning

    # Ensure the branch actually has cells, otherwise skip
    if not df_branch.empty:
        #print(f"Processing Sub-Lineage: '{root_lineage_str}' starting at t={start_t} (Exp: {exp}, FOV: {fov}, Trench: {trench})...")
        graph = create_lineage_graph(df_branch, device=device)
        if graph is not None:
            branch_graphs_list.append(graph)
        else:
            print(f"  Skipped (no valid connections/nodes) for '{root_lineage_str}' at t={start_t}")
    else:
        print(f"  Skipped (empty DataFrame) for '{root_lineage_str}' at t={start_t}")


print(f"\nSuccessfully created {len(branch_graphs_list)} PyG graphs, one per identified sub-lineage branch.")


Successfully created 34 PyG graphs, one per identified sub-lineage branch.


In [89]:
if branch_graphs_list:
    print("\nExample Sub-Lineage Graph Details:")
    for i, graph in enumerate(branch_graphs_list[:3]): # Print details for first 3 graphs
        print(f"\n--- Graph {i+1} ---")
        print(graph)
        print(f"  Nodes (x.shape): {graph.x.shape}")
        print(f"  Edges (edge_index.shape): {graph.edge_index.shape}")
        print(f"  Labels (y.shape): {graph.y.shape}")
        print(f"  Root Lineage Branch ID: {graph.root_lineage_branch}")
        print(f"  Start Time Frame: {graph.start_time_frame}")
        print(f"  Experiment Name: {graph.experiment_name}")
        print(f"  FOV: {graph.fov}")
        print(f"  Trench ID: {graph.trench_id}")
        # Optionally, print the actual ground truth lineage IDs in this subgraph for verification
        # print(f"  Included Lineages: {df_cells.loc[graph.node_id.cpu().numpy()]['ground_truth_lineage'].unique()}")


Example Sub-Lineage Graph Details:

--- Graph 1 ---
Data(x=[359, 10], edge_index=[2, 342], edge_attr=[342, 11], y=[359], pos=[359, 1], num_nodes=359, time_frame=[359], root_lineage_branch='A.1.1.2', start_time_frame=0, experiment_name='DUMM_giTG69_Glucose_013025', fov='007', trench_id='295')
  Nodes (x.shape): torch.Size([359, 10])
  Edges (edge_index.shape): torch.Size([2, 342])
  Labels (y.shape): torch.Size([359])
  Root Lineage Branch ID: A.1.1.2
  Start Time Frame: 0
  Experiment Name: DUMM_giTG69_Glucose_013025
  FOV: 007
  Trench ID: 295

--- Graph 2 ---
Data(x=[267, 10], edge_index=[2, 259], edge_attr=[259, 11], y=[267], pos=[267, 1], num_nodes=267, time_frame=[267], root_lineage_branch='A.2.1.2.1', start_time_frame=1, experiment_name='DUMM_giTG69_Glucose_013025', fov='007', trench_id='295')
  Nodes (x.shape): torch.Size([267, 10])
  Edges (edge_index.shape): torch.Size([2, 259])
  Labels (y.shape): torch.Size([267])
  Root Lineage Branch ID: A.2.1.2.1
  Start Time Frame: 1
  

In [90]:
# Define split ratios
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15 

# Ensure ratios sum to 1
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1."

# Sequentially split data because train_test_split can only split into 2 sets
train_graphs, temp_graphs = train_test_split(
    branch_graphs_list,
    test_size=(val_ratio + test_ratio), # Combined size for validation and test
    random_state=0 # set seed for reproducibility
)

# split temp_graphs into validation and test sets
# (test_ratio / (val_ratio + test_ratio)) ensures the correct proportion from the temporary set
val_graphs, test_graphs = train_test_split(
    temp_graphs,
    test_size=(test_ratio / (val_ratio + test_ratio)),
    random_state=0 # set seed for reproducibility
)

print(f"Total number of graphs: {len(branch_graphs_list)}")
print(f"Number of training graphs: {len(train_graphs)}")
print(f"Number of validation graphs: {len(val_graphs)}")
print(f"Number of test graphs: {len(test_graphs)}")

Total number of graphs: 34
Number of training graphs: 23
Number of validation graphs: 5
Number of test graphs: 6


In [91]:
train_dataset = LineageDataset(train_graphs)
val_dataset = LineageDataset(val_graphs)
test_dataset = LineageDataset(test_graphs)

print("\nPyG Datasets created:")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)


PyG Datasets created:
Train dataset size: 23
Validation dataset size: 5
Test dataset size: 6


In [92]:
# Determine number of node features in dataset
num_node_features = len(node_feature_cols)

# Determine the dimensionality of the initial edge features
# As per your DS_block, it's len(node_feature_cols) + 1 (for cosine similarity)
initial_edge_feature_dim = len(node_feature_cols) + 1

# Instantiate the model with the new parameters

model = LineageLinkPredictionGNN(
    in_channels=num_node_features,
    initial_edge_channels=initial_edge_feature_dim, # <-- NEW REQUIRED ARGUMENT
    hidden_channels=64, # tune (e.g., 16, 32, 128, 256)
    num_blocks=2 # <-- NEW OPTIONAL ARGUMENT, paper suggests L blocks
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss() # For binary classification of edges

In [93]:
# --- Define Hyperparameters and Training Configuration ---
epochs = 200 # You need to set the total number of training epochs
learning_rate = 0.001
hidden_channels = 128 # Example: for your GNN model
initial_node_channels = 10 # Example: based on len(node_feature_cols)
initial_edge_channels = 11 # Example: based on the output of your DS_block

In [None]:
# Define your radius_threshold based on your data's spatial scale
my_radius_threshold = 50.0 # Example value, adjust this for your data!

print("\nStarting link prediction training with improved negative sampling...")
for epoch in range(1, epochs + 1):
    train_loss = train_link_prediction(model, train_loader, optimizer, criterion, device, neg_sample_ratio=1.0, radius_threshold=my_radius_threshold)
    val_loss, val_acc = evaluate_link_prediction(model, val_loader, criterion, device, neg_sample_ratio=1.0, radius_threshold=my_radius_threshold)

    print(f'Epoch: {epoch:03d}, '
          f'Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# --- Final Evaluation on Test Set ---
print("\nEvaluating on test set with improved negative sampling...")
test_loss, test_acc = evaluate_link_prediction(model, test_loader, criterion, device, neg_sample_ratio=1.0, radius_threshold=my_radius_threshold)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')


Starting link prediction training with improved negative sampling...
Epoch: 001, Train Loss: 0.6963, Val Loss: 2.1442, Val Acc: 0.1910
Epoch: 002, Train Loss: 0.6947, Val Loss: 1.3850, Val Acc: 0.1910
Epoch: 003, Train Loss: 0.6839, Val Loss: 2.0352, Val Acc: 0.1910
Epoch: 004, Train Loss: 0.6796, Val Loss: 1.2075, Val Acc: 0.1910
Epoch: 005, Train Loss: 0.6752, Val Loss: 0.7974, Val Acc: 0.3568
Epoch: 006, Train Loss: 0.6704, Val Loss: 0.8771, Val Acc: 0.2563
Epoch: 007, Train Loss: 0.6601, Val Loss: 1.0731, Val Acc: 0.2060
Epoch: 008, Train Loss: 0.6573, Val Loss: 1.0615, Val Acc: 0.2060
Epoch: 009, Train Loss: 0.6510, Val Loss: 1.0568, Val Acc: 0.2563
Epoch: 010, Train Loss: 0.6467, Val Loss: 1.0725, Val Acc: 0.2312
Epoch: 011, Train Loss: 0.6447, Val Loss: 1.7366, Val Acc: 0.1910
Epoch: 012, Train Loss: 0.6433, Val Loss: 1.2357, Val Acc: 0.1910
Epoch: 013, Train Loss: 0.6303, Val Loss: 1.1381, Val Acc: 0.2010
Epoch: 014, Train Loss: 0.6346, Val Loss: 1.0771, Val Acc: 0.2010
Epoch: