In [1]:
# Imports
import os
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import topo as tp
from scipy.sparse import csr_matrix

import plotly.graph_objects as go
from ipywidgets import widgets, VBox, HBox, Layout
from IPython.display import display
from scipy.spatial import ConvexHull
from scipy.ndimage import gaussian_filter
import tifffile
import matplotlib.pyplot as plt
from typing import List, Tuple

# Try GPU UMAP (cuML), fallback to CPU UMAP.
# Both GPU_UMAP and CPU_UMAP are always bound (possibly to None) so later cells
# never hit a NameError: the TopOMetry path needs CPU UMAP (metric="precomputed")
# even when cuML is available.
try:
    from cuml import UMAP as GPU_UMAP
    use_gpu = True
    print("Using GPU-accelerated UMAP (cuML)")
except ImportError:
    GPU_UMAP = None
    use_gpu = False
    print("cuML not found, falling back to CPU UMAP")

try:
    from umap import UMAP as CPU_UMAP
except ImportError:
    CPU_UMAP = None
    # Without cuML, CPU UMAP is the only backend, so its absence is fatal.
    if not use_gpu:
        raise

Using GPU-accelerated UMAP (cuML)


In [2]:
RUN_NAME = 'Exp06_Site10'
EXP_NAME = 'Exp06_Site10'
# Project root. Override it with the APODET_BASE_DIR environment variable
# instead of editing this machine-specific default.
BASE_DIR = os.environ.get('APODET_BASE_DIR', '/home/nbahou/myimaging/apoDet_refactored')
# NOTE(review): INPUT_FILE is resolved against the kernel's working directory,
# not BASE_DIR. Confirm that is intended.
INPUT_FILE = f"{EXP_NAME}_features.jsonl"

# --- 1. Load per-cell features (JSON Lines: one record per line) ---
try:
    df_emb_list = pd.read_json(
        INPUT_FILE,
        orient='records',
        lines=True
    )
    print(f"DataFrame loaded successfully. Shape: {df_emb_list.shape}")
except FileNotFoundError:
    print(f"Error: Could not find '{INPUT_FILE}'. Please check the file path.")
    # Every later cell needs df_emb_list. Stop here rather than fail further
    # down with a confusing NameError.
    raise

# --- 2. Prepare data for UMAP (NumPy array 'feats') ---
# In JSON Lines the 'embedding' field is a native JSON array, so pandas already
# parses each value into a Python list of numbers. No string parsing
# (eval / literal_eval) is needed.
feats = np.array(df_emb_list['embedding'].tolist())
if feats.ndim != 2:
    # Ragged embeddings (different lengths per row) would yield an object array
    # (or a numpy error) instead of the (n_cells, n_features) matrix UMAP expects.
    raise ValueError(
        f"Expected a 2-D feature matrix, got shape {feats.shape}; "
        "check that every 'embedding' has the same length."
    )


DataFrame loaded successfully. Shape: (295414, 6)


In [3]:
# --- New Step 1.7: Add Apoptosis Annotation Feature ---

APOPTOSIS_FILE = f"{BASE_DIR}/data/{RUN_NAME}/apo_match_csv/{EXP_NAME}.csv"
APOPTOSIS_LABEL = 'apo'
NORMAL_LABEL = 'non-apo'

try:
    # 1. Load the apoptosis annotations.
    df_apo = pd.read_csv(APOPTOSIS_FILE)
    print(f"\nLoaded {df_apo.shape[0]} apoptosis annotations.")

    # Keep only the track id and the corrected apoptosis time ('correct_t').
    # The 'matching_track' column corresponds to 'track_id' in df_emb_list.
    df_apo = df_apo[['matching_track', 'correct_t']].rename(
        columns={'matching_track': 'track_id', 'correct_t': 't_apoptosis'}
    ).drop_duplicates(subset=['track_id'])  # keep the first apoptosis time per track

    # 2. Drop any 't_apoptosis' left by a previous run of this cell. Otherwise the
    #    merge creates t_apoptosis_x / t_apoptosis_y and step 4 raises KeyError.
    df_emb_list = df_emb_list.drop(columns=['t_apoptosis'], errors='ignore')

    # 3. Left-merge the apoptosis time onto every row of the matching track.
    #    Tracks without an annotation get NaN. A left merge keeps the left row
    #    order, which later positional assignments (feats/embedding) rely on.
    #    validate= fails loudly instead of silently duplicating rows.
    n_rows_before = len(df_emb_list)
    df_emb_list = pd.merge(
        df_emb_list,
        df_apo,
        on='track_id',
        how='left',
        validate='many_to_one'
    )
    assert len(df_emb_list) == n_rows_before, "apoptosis merge changed the row count"

    # 4. A track point is apoptotic if its track has an apoptosis time and its
    #    frame 't' is at or after that time. NaN comparisons are False, so
    #    unannotated tracks stay NORMAL_LABEL.
    df_emb_list['phenotype'] = np.where(
        df_emb_list['t'] >= df_emb_list['t_apoptosis'],
        APOPTOSIS_LABEL,
        NORMAL_LABEL
    )

    # 't_apoptosis' is kept for debugging. Drop it here if it is not needed:
    # df_emb_list = df_emb_list.drop(columns=['t_apoptosis'])

    # Report how many points were labeled.
    apo_count = (df_emb_list['phenotype'] == APOPTOSIS_LABEL).sum()
    total_count = df_emb_list.shape[0]
    print(f"Labeled {apo_count} / {total_count} track points as '{APOPTOSIS_LABEL}'.")

except FileNotFoundError:
    print(f"Warning: Apoptosis file '{APOPTOSIS_FILE}' not found. Skipping phenotype annotation.")
    if 'phenotype' not in df_emb_list.columns:
        df_emb_list['phenotype'] = NORMAL_LABEL



Loaded 749 apoptosis annotations.
Labeled 13412 / 295414 track points as 'apo'.


In [4]:
def compute_embedding(feats, use_gpu=True, topometry=False, n_neighbors=50, n_components=2, min_dist=0.1, random_state=42):
    """
    Compute a low-dimensional embedding with UMAP, optionally on a TopOMetry graph.

    Parameters
    ----------
    feats : np.ndarray or pd.DataFrame
        Feature matrix (cells x features).
    use_gpu : bool
        Use cuML's GPU UMAP instead of umap-learn. Ignored when ``topometry`` is
        True, because that path always uses umap-learn with a precomputed metric.
    topometry : bool
        Build a TopOMetry kernel graph first and run UMAP on the derived
        distance matrix ``D = 1 - K / max(K)``.
    n_neighbors : int
        Number of neighbors for UMAP and for the TopOMetry graph.
    n_components : int
        Number of output dimensions.
    min_dist : float
        UMAP ``min_dist`` parameter.
    random_state : int
        Random seed for reproducibility.

    Returns
    -------
    embedding : np.ndarray
        Array of shape (n_samples, n_components).

    Notes
    -----
    The TopOMetry path materialises a dense n x n distance matrix, which needs
    O(n^2) memory (about 700 GB of float64 for ~300k cells). Only use it on
    subsampled data.
    """
    if topometry:
        print("Applying TopOMetry to enhance graph...")
        tg = tp.TopOGraph(
            n_eigs=50,
            base_knn=n_neighbors,
            graph_knn=n_neighbors,
            random_state=random_state,
            n_jobs=-1
        )
        tg.fit(feats)
        tg.transform(feats)

        # Take the first kernel graph (a similarity matrix).
        kernel_obj = list(tg.GraphKernelDict.values())[0]
        K = kernel_obj.K.astype(float)

        # Normalize similarities to [0, 1]. Guard against an all-zero kernel.
        k_max = K.max()
        if k_max > 0:
            K = K / k_max

        # Convert to distance D = 1 - K over the WHOLE matrix. Non-edges
        # (K == 0) must become distance 1 (farthest), not 0. Subtracting only
        # the stored sparse entries and densifying left them at 0, i.e.
        # "closest".
        K_dense = K.toarray() if hasattr(K, 'toarray') else np.asarray(K)
        D = 1.0 - K_dense
        np.fill_diagonal(D, 0.0)

        data_for_umap = D
        metric = "precomputed"
        # Local import: independent of whichever backend the import cell bound.
        from umap import UMAP as UMAP_Class
    else:
        print("Running UMAP directly on features...")
        data_for_umap = feats
        metric = "cosine"
        if use_gpu:
            from cuml import UMAP as UMAP_Class
        else:
            from umap import UMAP as UMAP_Class

    reducer = UMAP_Class(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=metric,
        random_state=random_state
    )

    # NOTE: cuML logs that random_state forces brute-force kNN. That is
    # reproducible but O(n^2) in the number of cells.
    embedding = reducer.fit_transform(data_for_umap)
    print(f"Computed UMAP embedding of shape: {embedding.shape}")

    return embedding

embedding = compute_embedding(feats, use_gpu=use_gpu, topometry=False)

Running UMAP directly on features...
[2025-11-19 13:12:17.263] [CUML] [info] build_algo set to brute_force_knn because random_state is given
Computed UMAP embedding of shape: (295414, 2)


In [5]:
# Attach the UMAP coordinates as columns. Positional assignment is valid because
# `feats` (and therefore `embedding`) was built from df_emb_list['embedding'] in
# row order, and the apoptosis left-merge preserves that order.
# df_emb_list still holds every loaded column (path, embedding, t, track_id,
# phenotype, t_apoptosis, ...), not just metadata.
if len(embedding) != len(df_emb_list):
    raise ValueError(
        f"embedding has {len(embedding)} rows but df_emb_list has {len(df_emb_list)}; "
        "re-run the feature loading and embedding cells in order."
    )
df_emb_list['umap_1'] = embedding[:, 0]
df_emb_list['umap_2'] = embedding[:, 1]

print("\nUMAP coordinates added to df_emb_list.")


UMAP coordinates added to df_emb_list.


In [6]:
# --- Configuration ---
SAMPLE_MODE = "sample_per_track"   # options: "sample_tracks" or "sample_per_track"
N_TRACKS_TO_SAMPLE = 200        # used only in "sample_tracks" mode
N_SAMPLES_PER_TRACK = 50        # used only in "sample_per_track" mode
SAMPLE_SEED = 42                # random_state used by both sampling modes

# --- Subsetting Logic ---
unique_track_ids = df_emb_list['track_id'].unique()

if SAMPLE_MODE == "sample_tracks":
    # --- Mode 1: keep every row of a random subset of tracks ---
    n_actual_sample = min(N_TRACKS_TO_SAMPLE, len(unique_track_ids))
    sampled_ids = pd.Series(unique_track_ids).sample(
        n=n_actual_sample,
        random_state=SAMPLE_SEED
    ).tolist()
    df_subset = df_emb_list[df_emb_list['track_id'].isin(sampled_ids)].copy()

elif SAMPLE_MODE == "sample_per_track":
    # --- Mode 2: up to N_SAMPLES_PER_TRACK random rows from every track ---
    # The columns (grouping key included) are selected explicitly after groupby.
    # This keeps 'track_id' in each group and silences the pandas >= 2.2
    # DeprecationWarning about DataFrameGroupBy.apply operating on the grouping
    # columns. The result is identical to the implicit form.
    df_subset = (
        df_emb_list
        .groupby("track_id", group_keys=False)[df_emb_list.columns.tolist()]
        .apply(lambda g: g.sample(n=min(N_SAMPLES_PER_TRACK, len(g)), random_state=SAMPLE_SEED))
        .reset_index(drop=True)
    )

else:
    raise ValueError(f"Invalid SAMPLE_MODE '{SAMPLE_MODE}'. Choose 'sample_tracks' or 'sample_per_track'.")

# --- Summary ---
print(f"Sampling mode: {SAMPLE_MODE}")
print(f"Original DataFrame Size: {df_emb_list.shape[0]} rows")
print(f"Subset DataFrame Size: {df_subset.shape[0]} rows")
print(f"Total Unique Tracks in Subset: {df_subset['track_id'].nunique()}")


Sampling mode: sample_per_track
Original DataFrame Size: 295414 rows
Subset DataFrame Size: 106636 rows
Total Unique Tracks in Subset: 2751


  .apply(lambda g: g.sample(n=min(N_SAMPLES_PER_TRACK, len(g)), random_state=42))


In [7]:
import plotly.graph_objects as go
import plotly.colors
from ipywidgets import widgets, VBox, HBox, Layout
from IPython.display import display
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import ConvexHull
from scipy.ndimage import gaussian_filter
import numpy as np
import pandas as pd
import tifffile
import matplotlib.pyplot as plt
import os
from typing import List, Tuple

# ============================================================================
# CONFIGURATION & DATA PREPARATION
# ============================================================================

# Per-row columns pulled out of the main frame for plotting / neighbor lookup.
image_paths = df_emb_list['path'].tolist()
class_labels = df_emb_list['phenotype'].tolist()
embedding = df_emb_list[['umap_1', 'umap_2']].to_numpy()

# Sorted list of every track id, cast to int (track ID options).
all_track_ids = sorted(df_emb_list['track_id'].unique().astype(int))

# k-NN index over the 2-D UMAP space. update_neighbors_model() refits it
# whenever the displayed data is filtered.
nbrs = NearestNeighbors(n_neighbors=20, algorithm='kd_tree')
nbrs.fit(embedding)

# Column that receives labels drawn with the lasso in annotation mode.
if 'label_manual' not in df_emb_list.columns:
    df_emb_list['label_manual'] = None

# Mutable explorer state shared by the helpers and event handlers below.
filtered_df_emb_list = df_emb_list.copy()
filtered_embedding = embedding.copy()
selected_tracks = []  # track ids highlighted for comparison (may be several)
annotations = []      # user-placed text annotations: dicts with x, y, text

point_opacity = 0.7

print(f"Total dataset size: {len(df_emb_list):,} cells")
print(f"Unique phenotypes: {df_emb_list['phenotype'].nunique()}")
print(f"Unique tracks: {df_emb_list['track_id'].nunique()}")

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def get_nearest(x0: float, y0: float, k: int = 20) -> List[Tuple]:
    """Return the k points of the current filtered view closest to (x0, y0).

    Queries the module-level ``nbrs`` model, which is expected to be fit on
    ``filtered_embedding`` with rows aligned to ``filtered_df_emb_list``.

    Returns
    -------
    list of tuple
        ``(path, embedding_xy, distance, phenotype)`` per neighbor, nearest first.
    """
    neighbor_dists, neighbor_idx = nbrs.kneighbors([[x0, y0]], n_neighbors=k)

    hits = []
    for pos, dist in zip(neighbor_idx[0], neighbor_dists[0]):
        record = filtered_df_emb_list.iloc[pos]
        hits.append((record['path'], filtered_embedding[pos], dist, record['phenotype']))
    return hits


def compute_density_contours(df: pd.DataFrame, n_levels: int = 5, margin: float = 0.05):
    """Smoothed 2-D histogram of the UMAP coordinates, for contour plotting.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain 'umap_1' and 'umap_2'.
    n_levels : int
        Number of interior contour levels to return.
    margin : float
        Fraction of each axis' data range added as padding on both sides.

    Returns
    -------
    tuple or None
        ``(H.T, xedges, yedges, levels)`` with a 50x50 Gaussian-smoothed
        (sigma=1.5) histogram, or None when fewer than 10 rows are given.
    """
    if len(df) < 10:
        return None

    xs = df['umap_1'].values
    ys = df['umap_2'].values

    def _padded_bounds(values):
        span = values.max() - values.min()
        return [values.min() - margin * span, values.max() + margin * span]

    counts, x_edges, y_edges = np.histogram2d(
        xs, ys, bins=50, range=[_padded_bounds(xs), _padded_bounds(ys)]
    )
    smoothed = gaussian_filter(counts, sigma=1.5)
    # Evenly spaced levels strictly between the min and max density.
    interior_levels = np.linspace(smoothed.min(), smoothed.max(), n_levels + 2)[1:-1]

    return smoothed.T, x_edges, y_edges, interior_levels


def compute_density_per_phenotype(df, n_levels=5, margin: float = 0.05):
    """Smoothed 2-D UMAP density for each phenotype with at least 10 rows.

    Each phenotype uses its own padded bounding box (``margin`` fraction of the
    range on every side), 50x50 bins and a Gaussian blur with sigma=1.5.

    Returns
    -------
    dict
        ``{phenotype: (H.T, xedges, yedges, levels)}``, with ``n_levels``
        interior levels per phenotype. Phenotypes with fewer than 10 rows are
        omitted.
    """
    def _bounds(values):
        span = values.max() - values.min()
        return values.min() - margin * span, values.max() + margin * span

    result = {}
    for label in df['phenotype'].unique():
        group = df.loc[df['phenotype'] == label]
        if len(group) < 10:
            continue

        gx = group['umap_1'].values
        gy = group['umap_2'].values
        counts, ex, ey = np.histogram2d(gx, gy, bins=50, range=[_bounds(gx), _bounds(gy)])
        counts = gaussian_filter(counts, sigma=1.5)
        lvls = np.linspace(counts.min(), counts.max(), n_levels + 2)[1:-1]
        result[label] = (counts.T, ex, ey, lvls)
    return result


def compute_convex_hulls(df: pd.DataFrame):
    """Compute the convex hull of the UMAP points of each phenotype.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain 'phenotype', 'umap_1' and 'umap_2'.

    Returns
    -------
    dict
        ``{phenotype: (m, 2) array of hull vertices}`` in Qhull's
        counter-clockwise order. Phenotypes with fewer than 3 points, or whose
        points are degenerate (e.g. collinear), are skipped.
    """
    hulls = {}
    for phenotype in df['phenotype'].unique():
        subset = df[df['phenotype'] == phenotype]
        if len(subset) < 3:
            continue
        points = subset[['umap_1', 'umap_2']].values
        try:
            hull = ConvexHull(points)
        except Exception:
            # Best-effort: Qhull raises (QhullError / ValueError) for degenerate
            # inputs, and such a phenotype simply gets no hull. Catching
            # Exception instead of using a bare except lets KeyboardInterrupt
            # and SystemExit propagate.
            continue
        hulls[phenotype] = points[hull.vertices]
    return hulls


def show_neighbor_channels_grid(
    neighbors: List[Tuple], 
    k: int = 5, 
    figsize: Tuple[int, int] = (10, 10),
    neighbor_names: List[str] = None,
    n_channels: int = 5,
):
    """Display a grid of multi-channel cell images for nearest neighbors.

    One row per neighbor and one column per channel.

    Parameters
    ----------
    neighbors : list of tuple
        Items shaped like ``get_nearest`` output: ``(path, xy, distance, class)``.
    k : int
        Maximum number of rows (neighbors) to show.
    figsize : tuple of int
        Matplotlib figure size.
    neighbor_names : list of str, optional
        Row labels. Defaults to "Neighbor i".
    n_channels : int
        Expected number of image channels (default 5). Images stored
        channel-first ``(n_channels, H, W)`` are transposed to channel-last.

    Raises
    ------
    ValueError
        If a loaded image does not have ``n_channels`` channels in its last
        dimension.
    """
    k = min(k, len(neighbors))
    if k == 0:
        print("Neighbor list is empty.")
        return

    plt.close('all')

    # squeeze=False always yields a 2-D (k, n_channels) axes array, even for k == 1.
    fig, axes = plt.subplots(k, n_channels, figsize=figsize, squeeze=False)

    # Column titles go on the first row up front, so they appear even when the
    # first neighbor's image file is missing.
    for j in range(n_channels):
        axes[0, j].set_title(f"Channel {j+1}", fontsize=10)

    for i in range(k):
        path = neighbors[i][0]
        dist = neighbors[i][2]
        classification = neighbors[i][3]

        try:
            img = tifffile.imread(path)
        except FileNotFoundError:
            print(f"Error: File not found at {path}")
            for j in range(n_channels):
                ax = axes[i, j]
                ax.imshow(np.zeros((128, 128)), cmap="gray")
                ax.axis('off')
                if j == n_channels // 2:
                    ax.text(64, 64, 'File Not Found', color='white',
                            ha='center', va='center')
            continue

        # Channel-first (C, H, W) -> channel-last (H, W, C).
        if img.ndim == 3 and img.shape[0] == n_channels:
            img = img.transpose(1, 2, 0)

        if img.shape[-1] != n_channels:
            # Do not leave a half-drawn figure open behind the exception.
            plt.close(fig)
            raise ValueError(f"Image at {path} does not have {n_channels} channels in last dimension.")

        for j in range(n_channels):
            ax = axes[i, j]
            ax.imshow(img[..., j], cmap='gray', aspect='equal')
            ax.axis('off')

        if neighbor_names and i < len(neighbor_names):
            row_label_name = neighbor_names[i]
        else:
            row_label_name = f"Neighbor {i+1}"

        full_row_label = f"{row_label_name}\nClass: {classification}\nd={dist:.2f}"
        axes[i, 0].text(-0.3, 0.5, full_row_label,
                        transform=axes[i, 0].transAxes,
                        fontsize=9, ha='right', va='center')

    plt.tight_layout(pad=0.5, h_pad=1.0)
    plt.show()


def update_neighbors_model(df: pd.DataFrame):
    """Refit the global k-NN model on the UMAP coordinates of ``df``.

    Rebinds the module-level ``nbrs`` and ``filtered_embedding``. When ``df`` is
    empty, prints a warning and leaves both unchanged. The model uses at most
    20 neighbors (fewer if ``df`` has fewer rows).
    """
    global nbrs, filtered_embedding

    n_rows = len(df)
    if not n_rows:
        print("Warning: Filtered dataset is empty!")
        return

    filtered_embedding = df[['umap_1', 'umap_2']].values
    nbrs = NearestNeighbors(
        n_neighbors=min(20, n_rows),
        algorithm='kd_tree',
    ).fit(filtered_embedding)
    print(f"Nearest neighbors model updated with {n_rows:,} points")


def get_track_color(idx, n_tracks):
    """Return a named color for the idx-th selected track.

    Cycles through an 8-color palette. ``n_tracks`` is accepted for call-site
    compatibility but does not affect the result.
    """
    palette = ('red', 'blue', 'green', 'orange', 'purple', 'cyan', 'magenta', 'yellow')
    slot = idx % len(palette)
    return palette[slot]


def update_plot(df_to_plot: pd.DataFrame):
    """Clear and rebuild every trace of the global ``fig`` from ``df_to_plot``.

    Reads the module-level widgets (checkboxes, ``trajectory_color_mode``),
    ``selected_tracks`` and ``annotations``, and draws:
      * the cell scatter, coloured by phenotype or manual label (dimmed as a
        background when tracks are selected),
      * one highlighted trace per selected track, plus an optional trajectory,
      * optional density contours (global, or one per phenotype),
      * optional per-phenotype convex hulls (only when no track is selected),
      * the user annotations.
    Click and selection handlers are re-attached to the traces, and
    ``info_text`` gets a one-line summary.

    Parameters
    ----------
    df_to_plot : pd.DataFrame
        Rows to display. Must contain 'umap_1', 'umap_2', 'phenotype',
        'track_id' and 'path'. 'label_manual' is needed when the label
        checkboxes are on. 't' is optional.
    """
    global fig

    # Clear existing traces and annotations
    fig.data = []
    fig.layout.annotations = []

    if len(df_to_plot) == 0:
        info_text.value = "⚠️ No data matches current filters"
        return

    # Label -> colour map. It is (re)built in the scatter branches below, but the
    # per-phenotype density contours read it unconditionally. Without this
    # default, "filter to tracks" (no background trace) plus per-phenotype
    # density raised UnboundLocalError.
    color_map = {}

    def _contour_spec(levels):
        # `levels` is evenly spaced, so the step between lines is span / (n - 1).
        # Dividing by n drew lines that did not coincide with `levels`.
        step = (levels[-1] - levels[0]) / max(len(levels) - 1, 1)
        return dict(start=levels[0], end=levels[-1], size=step, coloring='none')

    # Determine filtering mode
    filter_to_tracks = track_filter_checkbox.value and len(selected_tracks) > 0

    # Prepare data based on filter mode
    if filter_to_tracks:
        # Only show selected tracks
        df_display = df_to_plot[df_to_plot['track_id'].isin(selected_tracks)].copy()
        df_background = pd.DataFrame()
    else:
        # Show all data, but highlight selected tracks
        df_display = df_to_plot.copy()
        df_background = df_to_plot[~df_to_plot['track_id'].isin(selected_tracks)] if len(selected_tracks) > 0 else df_to_plot

    # Add background points (dimmed if tracks are selected)
    if len(df_background) > 0 and len(selected_tracks) > 0:
        df_bg = df_background.copy()

        # Optionally restrict to manually labeled cells
        if show_only_labeled_checkbox.value:
            df_bg = df_bg[df_bg['label_manual'].notna()]

        # Sort so 'apo' points are drawn last (on top); phenotype mode only
        if not use_manual_labels_checkbox.value:
            df_bg = df_bg.sort_values(
                by='phenotype',
                key=lambda x: x == 'apo'
            )

        color_col = 'label_manual' if use_manual_labels_checkbox.value else 'phenotype'
        labels = df_bg[color_col].fillna('unlabeled')

        # Colour map
        unique_labels = labels.unique()
        colors = plotly.colors.qualitative.Plotly
        color_map = {lbl: colors[i % len(colors)] for i, lbl in enumerate(unique_labels)}
        color_map['unlabeled'] = 'lightgray'
        point_colors = [color_map[lbl] for lbl in labels]

        # Customdata & hover
        has_t = 't' in df_bg.columns
        custom_cols = ['track_id', 't', color_col] if has_t else ['track_id', color_col]
        customdata = df_bg[custom_cols].values
        hover_idx = 2 if has_t else 1

        scatter = go.Scattergl(
            x=df_bg['umap_1'],
            y=df_bg['umap_2'],
            mode='markers',
            name='Background',
            marker=dict(size=4, opacity=point_opacity * 0.35, color=point_colors),
            text=df_bg['path'],
            customdata=customdata,
            ids=df_bg.index,
            hovertemplate=(
                f'<b>{color_col.capitalize()}:</b> %{{customdata[{hover_idx}]}}<br>'
                '<b>UMAP 1:</b> %{x:.2f}<br>'
                '<b>UMAP 2:</b> %{y:.2f}<br>'
                + ('<b>Track:</b> %{customdata[0]}<br><b>Frame:</b> %{customdata[1]}<br>'
                   if has_t else '<b>Track:</b> %{customdata[0]}<br>') +
                '<extra></extra>'
            ),
            showlegend=False
        )
        fig.add_trace(scatter)

        # Legend-only dummy traces (hide 'unlabeled' in phenotype mode)
        for lbl in unique_labels:
            if lbl == 'unlabeled' and not use_manual_labels_checkbox.value:
                continue
            fig.add_trace(go.Scattergl(
                x=[None], y=[None],
                mode='markers',
                marker=dict(color=color_map[lbl], size=10),
                name=str(lbl),
                showlegend=True
            ))

    # Add selected tracks or all data if no tracks selected
    if len(selected_tracks) > 0:
        # Show each selected track with its own color
        for idx, track_id in enumerate(selected_tracks):
            track_data = df_display[df_display['track_id'] == track_id].copy()

            if len(track_data) == 0:
                continue

            track_color = get_track_color(idx, len(selected_tracks))

            # Add scatter points
            scatter = go.Scattergl(
                x=track_data['umap_1'],
                y=track_data['umap_2'],
                mode='markers',
                name=f'Track {track_id}',
                marker=dict(size=8, opacity=0.9, color=track_color, line=dict(width=1, color='white')),
                text=track_data['path'],
                customdata=track_data[['track_id', 't', 'phenotype']].values if 't' in track_data.columns else None,
                ids=track_data.index,
                hovertemplate=(
                    '<b>Track:</b> %{customdata[0]}<br>'
                    + ('<b>Frame:</b> %{customdata[1]}<br>' if 't' in track_data.columns else '') +
                    ('<b>Phenotype:</b> %{customdata[2]}<br>' if 't' in track_data.columns else '') +
                    '<b>UMAP 1:</b> %{x:.2f}<br>'
                    '<b>UMAP 2:</b> %{y:.2f}<br>'
                    '<extra></extra>'
                )
            )
            fig.add_trace(scatter)

            # Add trajectory if requested
            if track_connect_checkbox.value and 't' in track_data.columns and len(track_data) > 1:
                track_data = track_data.sort_values('t')

                # Choose coloring mode
                if trajectory_color_mode.value == 'Time':
                    marker_colors = track_data['t'].values
                    colorscale = 'Viridis'
                    colorbar_title = "Frame"
                elif trajectory_color_mode.value == 'Phenotype':
                    discrete_colors = ['#1f77b4', '#ff7f0e']
                    unique_phenotypes = track_data['phenotype'].unique()
                    if len(unique_phenotypes) <= len(discrete_colors):
                        phenotype_map = {p: discrete_colors[i] for i, p in enumerate(unique_phenotypes)}
                        marker_colors = track_data['phenotype'].map(phenotype_map).values
                        colorscale = None
                    else:
                        phenotype_map = {p: i for i, p in enumerate(unique_phenotypes)}
                        marker_colors = track_data['phenotype'].map(phenotype_map).values
                        colorscale = 'Rainbow'
                    colorbar_title = "Phenotype"
                else:  # Track color
                    marker_colors = track_color
                    colorscale = None
                    colorbar_title = None

                line_trace = go.Scattergl(
                    x=track_data['umap_1'],
                    y=track_data['umap_2'],
                    mode='lines+markers',
                    line=dict(color=track_color, width=2),
                    marker=dict(
                        size=10,
                        color=marker_colors,
                        colorscale=colorscale,
                        # Only the first track shows a colorbar, to avoid overlapping bars
                        showscale=(colorscale is not None and idx == 0),
                        colorbar=dict(title=colorbar_title, x=1.15) if colorscale else None,
                        line=dict(width=1, color='white')
                    ),
                    customdata=track_data[['t', 'phenotype']].values if 't' in track_data.columns else None,
                    ids=track_data.index,
                    hovertemplate=(
                        '<b>Track:</b> ' + str(track_id) + '<br>'
                        '<b>Frame:</b> %{customdata[0]}<br>'
                        '<b>Phenotype:</b> %{customdata[1]}<br>'
                        '<b>UMAP 1:</b> %{x:.2f}<br>'
                        '<b>UMAP 2:</b> %{y:.2f}<br>'
                        '<extra></extra>'
                    ),
                    name=f'Trajectory {track_id}',
                    showlegend=False
                )
                fig.add_trace(line_trace)

    # If no tracks selected, show regular phenotype view as a single trace
    elif len(selected_tracks) == 0:
        df_display = df_to_plot.copy()

        # Optionally restrict to manually labeled cells
        if show_only_labeled_checkbox.value:
            df_display = df_display[df_display['label_manual'].notna()]

        # Sort so 'apo' points are drawn last (on top); phenotype mode only
        if not use_manual_labels_checkbox.value:
            df_display = df_display.sort_values(
                by='phenotype',
                key=lambda x: x == 'apo'
            )

        color_col = 'label_manual' if use_manual_labels_checkbox.value else 'phenotype'
        labels = df_display[color_col].fillna('unlabeled')  # NA → gray

        # Build color map
        unique_labels = labels.unique()
        colors = plotly.colors.qualitative.Plotly
        color_map = {lbl: colors[i % len(colors)] for i, lbl in enumerate(unique_labels)}
        color_map['unlabeled'] = 'lightgray'

        point_colors = [color_map[lbl] for lbl in labels]

        has_t = 't' in df_display.columns
        custom_cols = ['track_id', 't', color_col] if has_t else ['track_id', color_col]
        customdata = df_display[custom_cols].values
        hover_idx = 2 if has_t else 1

        scatter = go.Scattergl(
            x=df_display['umap_1'],
            y=df_display['umap_2'],
            mode='markers',
            marker=dict(size=6, opacity=point_opacity, color=point_colors),
            text=df_display['path'],
            customdata=customdata,
            ids=df_display.index,
            hovertemplate=(
                f'<b>{color_col.capitalize()}:</b> %{{customdata[{hover_idx}]}}<br>'
                '<b>UMAP 1:</b> %{x:.2f}<br>'
                '<b>UMAP 2:</b> %{y:.2f}<br>'
                + ('<b>Track:</b> %{customdata[0]}<br><b>Frame:</b> %{customdata[1]}<br>'
                   if has_t else '<b>Track:</b> %{customdata[0]}<br>') +
                '<extra></extra>'
            ),
            name='All cells',
            showlegend=False
        )
        fig.add_trace(scatter)

        # Legend
        for lbl in unique_labels:
            if lbl == 'unlabeled' and not use_manual_labels_checkbox.value:
                continue  # hide gray from legend in phenotype mode
            fig.add_trace(go.Scattergl(
                x=[None], y=[None],
                mode='markers',
                marker=dict(color=color_map[lbl], size=10),
                name=str(lbl),
                showlegend=True
            ))

    # Add density contours if requested
    if show_density_checkbox.value:
        # ---- 1. Per-phenotype contours (optional) ----
        if per_phenotype_density_checkbox.value:
            per_pheno = compute_density_per_phenotype(df_to_plot)
            # Reuse the point colour map. Phenotypes missing from it (e.g. in
            # manual-label mode or with no scatter drawn) fall back to gray.
            for pheno, (H, xedges, yedges, levels) in per_pheno.items():
                fig.add_trace(go.Contour(
                    z=H,
                    x=xedges[:-1],
                    y=yedges[:-1],
                    contours=_contour_spec(levels),
                    line=dict(width=2.5, color=color_map.get(pheno, 'gray'), dash='dot'),
                    name=f'{pheno} density',
                    showlegend=False,
                    hoverinfo='skip'
                ))

        # ---- 2. Global contour (fallback) ----
        else:
            result = compute_density_contours(df_to_plot)
            if result:
                H, xedges, yedges, levels = result
                fig.add_trace(go.Contour(
                    z=H,
                    x=xedges[:-1],
                    y=yedges[:-1],
                    contours=_contour_spec(levels),
                    line=dict(width=2, color='rgba(0,0,0,0.3)'),
                    showscale=False,
                    name='Density',
                    hoverinfo='skip'
                ))

    # Add convex hulls if requested
    if show_hulls_checkbox.value and len(selected_tracks) == 0:
        hulls = compute_convex_hulls(df_to_plot)
        for phenotype, hull_points in hulls.items():
            # Close the hull
            hull_points_closed = np.vstack([hull_points, hull_points[0]])

            fig.add_trace(go.Scattergl(
                x=hull_points_closed[:, 0],
                y=hull_points_closed[:, 1],
                mode='lines',
                line=dict(width=2, dash='dash'),
                name=f'{phenotype} hull',
                showlegend=False,
                hoverinfo='skip'
            ))

    # Add user annotations
    for annotation in annotations:
        fig.add_annotation(
            x=annotation['x'],
            y=annotation['y'],
            text=annotation['text'],
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor='red',
            bgcolor='rgba(255,255,255,0.8)',
            bordercolor='red',
            borderwidth=2
        )

    # Re-attach click handlers to all traces
    fig.for_each_trace(lambda trace: trace.on_click(on_click_point))

    # Enable selection for scatter traces
    for trace in fig.data:
        if isinstance(trace, go.Scattergl) and hasattr(trace, 'on_selection'):
            trace.on_selection(on_lasso_select)

    # Update info display
    n_points = len(df_to_plot)
    n_classes = df_to_plot['phenotype'].nunique()
    n_tracks = df_to_plot['track_id'].nunique() if 'track_id' in df_to_plot.columns else 0

    if len(selected_tracks) > 0:
        info_text.value = f"📊 Displaying {n_points:,} cells | {n_classes} phenotypes | {n_tracks} tracks | 🎯 {len(selected_tracks)} selected"
    else:
        info_text.value = f"📊 Displaying {n_points:,} cells | {n_classes} phenotypes | {n_tracks} tracks"


# ============================================================================
# EVENT HANDLERS
# ============================================================================

def on_click_point(trace, points, state):
    """Handle a click on a UMAP scatter trace.

    In annotation mode, adds a text annotation (from ``annotation_text_input``
    or "Point N") at the clicked position and redraws the plot. Otherwise it
    shows the image grid of the ``k_slider`` nearest neighbors of the clicked
    point in ``output_neighbors``.

    Uses the plotly FigureWidget ``on_click`` callback signature
    ``(trace, points, state)``. Only ``points`` is read.
    """
    if not points.point_inds:
        return

    clicked_x = points.xs[0]
    clicked_y = points.ys[0]
    coords_display.value = f'x: {clicked_x:.2f}, y: {clicked_y:.2f}'

    # Annotation mode
    if annotation_mode_checkbox.value:
        annotation_text = annotation_text_input.value or f"Point {len(annotations) + 1}"
        annotations.append({'x': clicked_x, 'y': clicked_y, 'text': annotation_text})
        update_plot(filtered_df_emb_list)
        with output_neighbors:
            output_neighbors.clear_output(wait=True)
            print(f"✅ Added annotation: '{annotation_text}' at ({clicked_x:.2f}, {clicked_y:.2f})")
        return

    # Normal neighbor display mode
    with output_neighbors:
        output_neighbors.clear_output(wait=True)
        k_val = min(k_slider.value, len(filtered_df_emb_list))
        if k_val <= 0:
            # kneighbors() rejects n_neighbors=0. Report it clearly instead.
            print("No points available for a neighbor search.")
            return

        try:
            results = get_nearest(clicked_x, clicked_y, k=k_val)
            print(f"Displaying {len(results)} nearest neighbors for "
                  f"x={clicked_x:.2f}, y={clicked_y:.2f}")

            neighbor_names = [os.path.basename(res[0]) for res in results]
            show_neighbor_channels_grid(
                results,
                k=k_val,
                figsize=(10, 10),
                neighbor_names=neighbor_names
            )
        except Exception as e:
            # Best-effort UI callback: report the error instead of killing the widget.
            print(f"Error displaying neighbors: {e}")


def _next_manual_label():
    """Return an unused default label of the form ``manual_<n>``.

    ``n`` starts at one more than the number of distinct manual labels already
    present in ``df_emb_list['label_manual']`` (so default names count labels,
    not labelled cells) and is bumped past any name that is already taken.
    """
    existing = set(df_emb_list['label_manual'].dropna().unique())
    n = len(existing) + 1
    while f"manual_{n}" in existing:
        n += 1
    return f"manual_{n}"


def on_lasso_select(trace, points, selector):
    """Handle a box/lasso selection on the UMAP scatter plot.

    In annotation mode, every selected cell receives a manual label (the text
    box contents, or an auto-generated ``manual_<n>`` name when it is blank)
    in both ``df_emb_list`` and ``filtered_df_emb_list``, and the plot is
    redrawn. Otherwise a summary of the selection (phenotypes, tracks, frames,
    UMAP extent) is printed into ``output_selection_stats`` together with a
    button that exports the selected rows to CSV.

    Args:
        trace: Plotly trace the selection belongs to; ``trace.ids`` holds the
            DataFrame index of each plotted point.
        points: Plotly points object with the selected ``point_inds``.
        selector: Plotly selector object (unused).
    """
    if not points.point_inds:
        print('no point IDs')
        return

    # Map selected point positions back to DataFrame index labels.
    sel_idx = [int(trace.ids[i]) for i in points.point_inds]
    selected_data = filtered_df_emb_list.loc[sel_idx].copy()

    # NOTE(review): if update_plot attaches this handler to several traces
    # (e.g. one per phenotype), one lasso presumably arrives as one call per
    # trace, so labels/stats below are computed per trace — confirm.

    # --- LABELING MODE ---
    if annotation_mode_checkbox.value:
        label = annotation_text_input.value.strip() or _next_manual_label()

        # Write to the full dataset and to the filtered view so both agree
        # without recomputing the filter.
        df_emb_list.loc[sel_idx, 'label_manual'] = label
        filtered_df_emb_list.loc[sel_idx, 'label_manual'] = label

        with output_selection_stats:
            output_selection_stats.clear_output()
            print(f"Labeled {len(selected_data)} cells ‚Üí **{label}**")

        update_plot(filtered_df_emb_list)
        return

    # --- STATISTICS MODE ---
    with output_selection_stats:
        output_selection_stats.clear_output(wait=True)

        print(f"üìä Selected {len(selected_data)} cells\n")

        # Phenotype distribution
        print("üî¨ Phenotype distribution:")
        phenotype_counts = selected_data['phenotype'].value_counts()
        for pheno, count in phenotype_counts.items():
            pct = 100 * count / len(selected_data)
            print(f"  {pheno}: {count} ({pct:.1f}%)")

        # Track information
        if 'track_id' in selected_data.columns:
            n_tracks = selected_data['track_id'].nunique()
            print(f"\nüõ§Ô∏è Tracks represented: {n_tracks}")

            # Show most common tracks
            track_counts = selected_data['track_id'].value_counts().head(5)
            print("  Top tracks:")
            for track, count in track_counts.items():
                print(f"    Track {track}: {count} cells")

        # Frame/time information
        if 't' in selected_data.columns:
            print(f"\n‚è±Ô∏è Frame range: {selected_data['t'].min()} - {selected_data['t'].max()}")
            print(f"  Mean frame: {selected_data['t'].mean():.1f}")

        # UMAP coordinate stats
        print(f"\nüìç UMAP coordinates:")
        print(f"  X range: [{selected_data['umap_1'].min():.2f}, {selected_data['umap_1'].max():.2f}]")
        print(f"  Y range: [{selected_data['umap_2'].min():.2f}, {selected_data['umap_2'].max():.2f}]")

        # Export button
        export_button = widgets.Button(
            description='Export selection to CSV',
            button_style='info',
            icon='download'
        )

        def export_selection(b):
            filename = f"umap_selection_{len(selected_data)}_cells.csv"
            selected_data.to_csv(filename, index=False)
            # Button callbacks run after the surrounding `with` block has
            # exited, so capture the confirmation in the stats panel
            # explicitly; a bare print would not be shown there.
            with output_selection_stats:
                print(f"\n‚úÖ Exported to {filename}")

        export_button.on_click(export_selection)
        display(export_button)


def on_track_input_change(change):
    """Replace the track selection with the IDs typed into the "Track IDs" box.

    The text is parsed as comma-separated integers; empty tokens (e.g. from a
    trailing comma while typing "1,5,") are skipped. IDs not present in
    ``all_track_ids`` are dropped and listed in the status line. The surviving
    IDs are deduplicated and sorted, matching the ordering kept by the "Add"
    button, and the plot is redrawn. Blank input is ignored, so emptying the
    box (as the "Clear" button does) leaves the current selection untouched.

    Args:
        change: traitlets change notification; ``change.new`` is the new text.
    """
    global selected_tracks

    input_val = change.new.strip()
    if not input_val:
        return

    tokens = [part.strip() for part in input_val.split(',')]
    # Only the parsing is guarded: errors raised later (e.g. by update_plot)
    # must not be reported as malformed input.
    try:
        track_ids = [int(tok) for tok in tokens if tok]
    except ValueError:
        track_list_display.value = "‚ö†Ô∏è Invalid input (use comma-separated numbers)"
        return

    known = set(all_track_ids)  # constant-time membership checks
    valid_tracks = sorted({t for t in track_ids if t in known})
    if not valid_tracks:
        track_list_display.value = "‚ö†Ô∏è No valid tracks found"
        return

    unknown = sorted({t for t in track_ids if t not in known})
    selected_tracks = valid_tracks
    status = f"Selected: {', '.join(map(str, selected_tracks))}"
    if unknown:
        status += f" (ignored unknown: {', '.join(map(str, unknown))})"
    track_list_display.value = status
    update_plot(filtered_df_emb_list)


def on_track_add_button(b):
    """Add the track chosen in the dropdown to the selection and redraw.

    Nothing happens when the dropdown has no value (``None``) or the track is
    already selected. The selection is kept sorted.

    Args:
        b: the clicked button (unused).
    """
    global selected_tracks

    track_val = track_id_dropdown.value
    # Compare with None explicitly: a truthiness test would reject track ID 0.
    if track_val is None or track_val in selected_tracks:
        return

    selected_tracks.append(track_val)
    selected_tracks.sort()
    track_list_display.value = f"Selected: {', '.join(map(str, selected_tracks))}"
    update_plot(filtered_df_emb_list)


def on_track_clear_button(b):
    """Empty the track selection, restore the track widgets and redraw.

    Args:
        b: the clicked button (unused).
    """
    global selected_tracks

    selected_tracks = list()
    # Status line and free-text box go back to their initial contents.
    track_list_display.value = "Selected: None"
    track_input.value = ""

    update_plot(filtered_df_emb_list)


def on_phenotype_filter_change(change):
    """Rebuild the working dataset from the phenotypes chosen in the filter.

    Choosing no phenotype means "no filter": every row of ``df_emb_list`` is
    kept. The result is always a fresh copy. The neighbor model is refit on it
    before the plot is redrawn.

    Args:
        change: traitlets change notification; ``change.new`` holds the chosen
            phenotype values.
    """
    global filtered_df_emb_list

    chosen = change.new
    source = (
        df_emb_list
        if len(chosen) == 0
        else df_emb_list[df_emb_list['phenotype'].isin(chosen)]
    )
    filtered_df_emb_list = source.copy()

    update_neighbors_model(filtered_df_emb_list)
    update_plot(filtered_df_emb_list)


def on_visual_option_change(change):
    """Redraw the current working dataset; the change payload is not inspected."""
    current = filtered_df_emb_list
    update_plot(current)


def on_point_size_change(b, size):
    """Set the marker size of every trace in ``fig`` that has a sized marker.

    Traces without a ``marker`` attribute (or whose marker has no ``size``)
    are skipped. The edits are wrapped in ``fig.batch_update()`` so the
    frontend receives one combined update instead of one message per trace.

    NOTE(review): the chosen size is not stored anywhere, so it presumably
    reverts the next time update_plot rebuilds the traces — confirm.

    Args:
        b: the clicked button (unused).
        size: new marker size.
    """
    with fig.batch_update():
        for trace in fig.data:
            marker = getattr(trace, 'marker', None)
            if marker is not None and hasattr(marker, 'size'):
                marker.size = size


def on_opacity_change(change):
    """Store the slider's new value in the global ``point_opacity`` and redraw.

    Args:
        change: traitlets change notification, read via ``change['new']``.
    """
    global point_opacity
    new_value = change['new']
    point_opacity = new_value
    update_plot(filtered_df_emb_list)
    

def on_clear_annotations(b):
    """Erase every manual label and click annotation, then redraw.

    ``label_manual`` is reset to ``pd.NA`` in both the full and the filtered
    dataset, the ``annotations`` list is replaced with an empty one, and a
    confirmation is printed into the neighbors panel.

    Args:
        b: the clicked button (unused).
    """
    global annotations

    for frame in (df_emb_list, filtered_df_emb_list):
        frame['label_manual'] = pd.NA
    annotations = []

    update_plot(filtered_df_emb_list)
    with output_neighbors:
        output_neighbors.clear_output()
        print("Cleared all labels and annotations")


def on_reset_filters(b):
    """Reset track selection, filters, overlays and annotations to defaults.

    Manual labels stored in ``label_manual`` are kept (the "Clear Labels &
    Annotations" button removes those), as are point size, opacity, the k
    slider and the drag mode.

    Each widget assignment below would normally fire that widget's ``value``
    observer, redrawing the plot (and, for the phenotype filter, refitting the
    neighbor model) once per changed widget. Those observers are detached
    while the widgets are reset and re-attached afterwards, so the refit and
    redraw run exactly once at the end.

    Args:
        b: the clicked button (unused).
    """
    global filtered_df_emb_list, selected_tracks, annotations

    # (widget, observer) pairs; must mirror the `.observe(..., names='value')`
    # calls made when the widgets are wired up.
    watched = [
        (track_input, on_track_input_change),
        (track_filter_checkbox, on_visual_option_change),
        (track_connect_checkbox, on_visual_option_change),
        (trajectory_color_mode, on_visual_option_change),
        (phenotype_filter, on_phenotype_filter_change),
        (show_density_checkbox, on_visual_option_change),
        (per_phenotype_density_checkbox, on_visual_option_change),
        (show_hulls_checkbox, on_visual_option_change),
        (annotation_mode_checkbox, on_visual_option_change),
        (use_manual_labels_checkbox, on_visual_option_change),
        (show_only_labeled_checkbox, on_visual_option_change),
    ]
    detached = []
    for widget, handler in watched:
        try:
            widget.unobserve(handler, names='value')
        except ValueError:
            # Handler was not attached to this widget; nothing to restore.
            continue
        detached.append((widget, handler))

    try:
        selected_tracks = []
        annotations = []
        track_id_dropdown.value = all_track_ids[0]
        track_input.value = ""
        track_list_display.value = "Selected: None"
        track_filter_checkbox.value = False
        track_connect_checkbox.value = False
        trajectory_color_mode.value = 'Time'
        phenotype_filter.value = []
        show_density_checkbox.value = False
        per_phenotype_density_checkbox.value = False
        show_hulls_checkbox.value = False
        annotation_mode_checkbox.value = False
        annotation_text_input.value = ''
        use_manual_labels_checkbox.value = True
        # This is a filter as well: leaving it on would keep hiding every
        # unlabeled cell after "Reset All".
        show_only_labeled_checkbox.value = False

        filtered_df_emb_list = df_emb_list.copy()
    finally:
        for widget, handler in detached:
            widget.observe(handler, names='value')

    update_neighbors_model(filtered_df_emb_list)
    update_plot(filtered_df_emb_list)
    output_neighbors.clear_output()
    output_selection_stats.clear_output()
    coords_display.value = ''


def set_dragmode(mode):
    """Apply ``mode`` as the figure's drag mode and highlight its button.

    Args:
        mode: one of 'pan', 'select' or 'lasso'; the matching button gets the
            'info' style and the other two are cleared.
    """
    fig.layout.dragmode = mode
    mode_buttons = (
        ('pan', dragmode_pan),
        ('select', dragmode_select),
        ('lasso', dragmode_lasso),
    )
    for button_mode, button in mode_buttons:
        button.button_style = 'info' if button_mode == mode else ''


# ============================================================================
# WIDGETS
# ============================================================================
# Every name defined here is a module-level global read by the event handlers
# above, so the names must not change.

# Output widgets
# output_neighbors: panel for click feedback (neighbor grids, annotation messages).
# output_selection_stats: panel for box/lasso selection summaries and labelling messages.
output_neighbors = widgets.Output()
output_selection_stats = widgets.Output()

# Controls
# Number of nearest neighbors shown on click; on_click_point caps it at the
# size of the filtered dataset.
k_slider = widgets.IntSlider(
    value=5, min=1, max=20, step=1,
    description='Neighbors (k):',
    continuous_update=False,
    layout=Layout(width='300px')
)

# Read-only field that on_click_point fills with the clicked UMAP coordinates.
coords_display = widgets.Text(
    value='', 
    placeholder='Click on a point',
    description='Clicked:',
    disabled=True,
    layout=Layout(width='400px')
)

# Track selection controls
# Dropdown of all track IDs; the "Add" button appends its value to selected_tracks.
# NOTE(review): value=all_track_ids[0] assumes at least one track exists.
track_id_dropdown = widgets.Dropdown(
    options=all_track_ids,
    value=all_track_ids[0],
    description='Add Track:',
    layout=Layout(width='150px')
)

# Free-text comma-separated track IDs, parsed by on_track_input_change.
track_input = widgets.Text(
    value='',
    placeholder='e.g., 1,5,12',
    description='Track IDs:',
    layout=Layout(width='200px')
)

track_add_button = widgets.Button(
    description='Add',
    button_style='success',
    layout=Layout(width='60px')
)

track_clear_button = widgets.Button(
    description='Clear',
    button_style='danger',
    layout=Layout(width='60px')
)

# Status line listing the selected tracks (or a parse warning).
track_list_display = widgets.HTML(
    value="Selected: None",
    layout=Layout(width='400px')
)

# The next three widgets only trigger a redraw when changed; presumably
# update_plot reads their values to filter/connect/color tracks.
track_filter_checkbox = widgets.Checkbox(
    value=False,
    description='Show only selected tracks',
    indent=False
)

track_connect_checkbox = widgets.Checkbox(
    value=False,
    description='Show trajectories',
    indent=False
)

trajectory_color_mode = widgets.Dropdown(
    options=['Time', 'Phenotype', 'Track'],
    value='Time',
    description='Color by:',
    layout=Layout(width='180px')
)

# Phenotype filter
# An empty selection means "show every phenotype" (see on_phenotype_filter_change).
# NOTE(review): sorted() raises TypeError if 'phenotype' mixes strings with
# NaN — confirm the column has no missing values.
phenotype_filter = widgets.SelectMultiple(
    options=sorted(df_emb_list['phenotype'].unique()),
    value=[],
    description='Phenotypes:',
    layout=Layout(width='200px', height='100px')
)

# Visual options
# Overlay toggles; each change triggers a redraw via on_visual_option_change.
show_density_checkbox = widgets.Checkbox(
    value=False,
    description='Show density contours',
    indent=False
)

per_phenotype_density_checkbox = widgets.Checkbox(
    value=False,
    description='Density per phenotype',
    indent=False,
    layout=Layout(width='210px')
)

show_hulls_checkbox = widgets.Checkbox(
    value=False,
    description='Show convex hulls',
    indent=False
)

# Annotation mode
# When checked, clicks add text annotations and box/lasso selections assign
# manual labels instead of showing neighbors / statistics.
annotation_mode_checkbox = widgets.Checkbox(
    value=False,
    description='Annotation mode',
    indent=False
)

# Annotation text for clicks, or label name for box/lasso selections; blank
# means an auto-generated name.
annotation_text_input = widgets.Text(
    value='',
    placeholder='Annotation text...',
    description='Text:',
    layout=Layout(width='250px')
)

# View options for manual labels; presumably read by update_plot.
use_manual_labels_checkbox = widgets.Checkbox(
    value=True,
    description='Color by manual labels',
    indent=False,
    layout=Layout(width='220px')
)
show_only_labeled_checkbox = widgets.Checkbox(
    value=False,
    description='Show only labeled cells',
    indent=False,
    layout=Layout(width='220px')
)

clear_annotations_button = widgets.Button(
    description='Clear Labels & Annotations',
    button_style='warning',
    layout=Layout(width='auto')
)

# Info display
# Initial summary only; presumably overwritten by update_plot on every redraw.
info_text = widgets.HTML(
    value=f"üìä Displaying {len(df_emb_list):,} cells | "
          f"{df_emb_list['phenotype'].nunique()} phenotypes",
    layout=Layout(width='600px')
)

# Point size buttons
point_sizes = [2, 4, 6, 8]
size_buttons = []
for size in point_sizes:
    button = widgets.Button(
        description=f'Size {size}', 
        layout=Layout(width='60px')
    )
    # `s=size` binds the current loop value; a bare closure over `size` would
    # make every button use the last size.
    button.on_click(lambda b, s=size: on_point_size_change(b, s))
    size_buttons.append(button)

# Marker opacity; on_opacity_change copies the value into the global
# point_opacity and redraws.
point_opacity_slider = widgets.FloatSlider(
    value=0.7,          # initial opacity; NOTE(review): keep in sync with point_opacity's initial value
    min=0.0,
    max=1.0,
    step=0.05,
    description='Point opacity:',
    continuous_update=False,
    readout=True,
    readout_format='.2f',
    layout=Layout(width='260px')
)

# Dragmode buttons
# "Select" starts highlighted to match the figure's initial dragmode='select'.
dragmode_pan = widgets.Button(description='Pan', button_style='', layout=Layout(width='70px'))
dragmode_select = widgets.Button(description='Select', button_style='info', layout=Layout(width='70px'))
dragmode_lasso = widgets.Button(description='Lasso', button_style='', layout=Layout(width='70px'))

dragmode_pan.on_click(lambda b: set_dragmode('pan'))
dragmode_select.on_click(lambda b: set_dragmode('select'))
dragmode_lasso.on_click(lambda b: set_dragmode('lasso'))

# Reset button
reset_button = widgets.Button(
    description='Reset All',
    button_style='danger',
    layout=Layout(width='auto')
)

# ============================================================================
# PLOT INITIALIZATION
# ============================================================================

# A FigureWidget (rather than go.Figure) so the handlers above can edit it
# live; presumably update_plot also attaches on_click_point / on_lasso_select
# to its traces.
fig = go.FigureWidget()
# First draw of the current working dataset.
update_plot(filtered_df_emb_list)

# Layout is applied after the first draw. dragmode='select' matches the
# "Select" button, which is created pre-highlighted.
fig.update_layout(
    title="scDINO UMAP Embedding - Interactive Explorer",
    width=900,
    height=700,
    showlegend=True,
    legend=dict(
        orientation="v",
        yanchor="top",
        y=1,
        xanchor="left",
        x=1.02
    ),
    hovermode='closest',
    dragmode='select'
)

# ============================================================================
# ATTACH EVENT HANDLERS
# ============================================================================

# Button callbacks.
track_add_button.on_click(on_track_add_button)
track_clear_button.on_click(on_track_clear_button)
clear_annotations_button.on_click(on_clear_annotations)
reset_button.on_click(on_reset_filters)

# Value observers: each widget's handler fires whenever its `value` changes.
for _widget, _handler in (
    (track_input, on_track_input_change),
    (track_filter_checkbox, on_visual_option_change),
    (track_connect_checkbox, on_visual_option_change),
    (trajectory_color_mode, on_visual_option_change),
    (phenotype_filter, on_phenotype_filter_change),
    (show_density_checkbox, on_visual_option_change),
    (per_phenotype_density_checkbox, on_visual_option_change),
    (show_hulls_checkbox, on_visual_option_change),
    (annotation_mode_checkbox, on_visual_option_change),
    (use_manual_labels_checkbox, on_visual_option_change),
    (show_only_labeled_checkbox, on_visual_option_change),
    (point_opacity_slider, on_opacity_change),
):
    _widget.observe(_handler, names='value')
del _widget, _handler

# ============================================================================
# DISPLAY LAYOUT
# ============================================================================

# Each panel is assembled on its own, then everything is stacked vertically.
_track_panel = VBox(
    [
        widgets.HTML(value="<b>üõ§Ô∏è Track Selection:</b>"),
        HBox([track_id_dropdown, track_add_button, track_input, track_clear_button]),
        track_list_display,
        HBox([track_filter_checkbox, track_connect_checkbox]),
        HBox([
            widgets.HTML(value="<b>Trajectory coloring:</b>", layout=Layout(width='120px')),
            trajectory_color_mode,
        ]),
    ],
    layout=Layout(border='1px solid #ddd', padding='10px', margin='5px'),
)

_phenotype_panel = VBox(
    [
        widgets.HTML(value="<b>üî¨ Phenotype Filter:</b>"),
        phenotype_filter,
    ],
    layout=Layout(padding='5px'),
)

_visual_panel = VBox(
    [
        widgets.HTML(value="<b>üé® Visual Options:</b>"),
        show_density_checkbox,
        per_phenotype_density_checkbox,
        show_hulls_checkbox,
        widgets.HTML(value="<b>Point Size:</b>"),
        HBox(size_buttons),
        widgets.HTML(value="<b>Point Opacity:</b>"),
        point_opacity_slider,
        widgets.HTML(value="<b>Interaction Mode:</b>"),
        HBox([dragmode_pan, dragmode_select, dragmode_lasso]),
    ],
    layout=Layout(padding='5px'),
)

_annotation_panel = VBox(
    [
        widgets.HTML(value="<b>üìù Annotations:</b>"),
        annotation_mode_checkbox,
        annotation_text_input,
        widgets.HTML(value="<br><b>View Options:</b>"),
        use_manual_labels_checkbox,
        show_only_labeled_checkbox,
        widgets.HTML(value="<br><b>Actions:</b>"),
        clear_annotations_button,
        widgets.HTML(value="<br>"),
        reset_button,
    ],
    layout=Layout(padding='5px'),
)

_output_row = HBox([
    VBox(
        [widgets.HTML(value="<b>üë• Clicked Point Neighbors:</b>"), output_neighbors],
        layout=Layout(width='70%'),
    ),
    VBox(
        [widgets.HTML(value="<b>üìä Selection Statistics:</b>"), output_selection_stats],
        layout=Layout(width='30%'),
    ),
])

display(VBox(
    [
        HBox([info_text]),                   # header
        HBox([coords_display, k_slider]),    # main controls row
        _track_panel,
        HBox([_phenotype_panel, _visual_panel, _annotation_panel]),
        fig,
        _output_row,
    ],
    layout=Layout(padding='10px'),
))

Total dataset size: 295,414 cells
Unique phenotypes: 2
Unique tracks: 2751


VBox(children=(HBox(children=(HTML(value='üìä Displaying 295,414 cells | 2 phenotypes | 2751 tracks', layout=Lay‚Ä¶