# Endocytosis

<img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/end1.gif" height="200">
<img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/end2.gif" height="200">

TIRF can visualize the formation of clathrin-coated pits, recruitment of adaptor proteins, and vesicle scission from the plasma membrane.

Since only the membrane-proximal region is illuminated, once vesicles move away from the membrane (into the cytoplasm), they disappear from the TIRF field, indicating internalization.



<img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/shape_header.svg" height="350">

SHAPE is an open-source framework for processing TIRF-SIM time-lapse data. It enables reconstruction, detection, linking, and analysis of plasma membrane processes such as clathrin-mediated endocytosis and exocytosis. With its modular design and automated workflow, SHAPE enhances tracking accuracy, reduces manual intervention, and supports high-throughput analysis of vesicle transport and productivity.

Explore reconstruction, object detection, linking, and analysis in this tutorial to assess your super-resolution data.

For more information, visit [GitHub repository link].

**📋Instructions**

1. 🛠️ Install Dependencies -
Automatically install requirements and download test data.
1. 🗂️ Load Input Data - Use the provided test data or upload your own dataset.

2. ⚙️ Set Parameters -
Adjust parameters in the provided setup cell or use defaults.

3. 🧩 Reconstruct TIRF-SIM Images -
Combine 9 low-resolution images into a super-resolution reconstruction with preprocessing to reduce noise and artifacts.

4. 🔍 Detect Objects -
Use neural networks for object localization and shape extraction, optimized for robust segmentation.

5. 🔗 Linking -
Connect detections across frames by optimizing spatial and shape-based matching for globally consistent trajectories.

6. 📈 Analyze -
Analyze complete object trajectories to test hypotheses and derive biological insights.

💡 Notes:
* If possible, enable GPU in Google Colab by selecting Runtime > Change runtime type, choosing GPU, and clicking Save.

* If using large files, you might run out of RAM. Consider using smaller data or upgrading to Colab Pro for more memory and faster GPU access.

Start each section below by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px">.


In [None]:
#@title 🛠️ Install Dependencies

%autosave 300
import requests
import zipfile
import os
import sys
import subprocess
import json
import shutil
import random

def download_and_unzip(url, extract_to, chain_path):
    """Download a zip archive and unpack it, unless the target already exists.

    Args:
        url: Address of the zip archive. The text after the last "/" is used
            as the name of the temporary local file.
        extract_to: Directory to unpack into. If this path already exists the
            function prints a notice and returns without touching the network.
        chain_path: Forwarded to ``requests.get`` as its ``verify`` argument.

    Raises:
        requests.HTTPError: The server answered with an error status.
        zipfile.BadZipFile: The downloaded file is not a valid zip archive.
    """
    if os.path.exists(extract_to):
        print(f"The directory '{extract_to}' already exists. Skipping download and extraction.")
        return

    local_zip_file = url.split("/")[-1]

    # The context manager releases the connection; the timeout keeps a stalled
    # server from hanging the cell forever.
    with requests.get(url, stream=True, verify=chain_path, timeout=60) as response:
        # Fail here rather than saving an HTML error page as "the zip".
        response.raise_for_status()
        try:
            with open(local_zip_file, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        except BaseException:
            # Do not leave a truncated archive behind.
            if os.path.exists(local_zip_file):
                os.remove(local_zip_file)
            raise

    try:
        os.makedirs(extract_to, exist_ok=True)
        with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    except BaseException:
        # A half-extracted directory would make the next call skip the
        # download (see the existence check above), so remove it.
        shutil.rmtree(extract_to, ignore_errors=True)
        raise
    finally:
        os.remove(local_zip_file)
    print(f"Downloaded and extracted to '{extract_to}'.")

def install_packages(packages):
    """Install each named package with pip unless it can already be imported.

    Args:
        packages: Iterable of names. Each name is used both as the module name
            for the import probe and as the pip requirement.

    A failed installation is reported on stdout and does not stop the loop.
    """
    import importlib

    for package in packages:
        try:
            importlib.import_module(package)
        except ImportError:
            pass
        else:
            print(f"'{package}' is already installed. Skipping.")
            continue

        print(f"Installing '{package}'...")
        try:
            # Run pip through the running interpreter so the package lands in
            # this kernel's environment, not whichever "pip" is first on PATH.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", package],
                check=True,
                capture_output=True,
                encoding="utf-8",
                errors="replace",
            )
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while installing '{package}': {(e.stderr or '').strip()}")
        else:
            # Let the import system notice the freshly installed files.
            importlib.invalidate_caches()
            print(f"'{package}' installed successfully.")


# Download SSL certificate
# The chain is stored locally and handed to download_and_unzip below, which
# forwards it to requests as the `verify` argument for the data download.
chain_path = '/content/chain-harica-cross.pem'
r = requests.get(
    'https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem',
    timeout=10, stream=True
)
r.raise_for_status()  # stop on an HTTP error instead of saving an error page
with open(chain_path, 'wb') as f:
    f.write(r.content)

# Download test data
zip_url = "https://shape.utia.cas.cz/files/endocytosis/shape2.0_107.zip"
extract_directory = "/content/shape"
download_and_unzip(zip_url, extract_directory, chain_path)
# Work from inside the extracted tree: the relative 'src/' path below and the
# relative 'data/...' paths used later in the notebook resolve against it.
os.chdir(extract_directory)
# Make the modules under src/ importable by the import statements that follow.
module_path = os.path.abspath(os.path.join('src/'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Extra packages; install_packages only installs the ones that fail to import.
required_packages = ["ortools", "czifile", "mrcfile", "nd2"]
install_packages(required_packages)

import torch
import numpy as np
import pandas as pd
from google.colab import drive

from scipy import spatial

from preprocessing import subtract_background, fade_border, deconvolve_richardson_lucy
from parameter_estimation import estimate_phase_modulation, estimate_phase_offset, estimate_integer_shift, optimizer_shift_v2, compute_minimum_distance
from reconstruction import separate_components, run_reconstruction
from otf import OTF
from gpu_reconstruction import Reconstruction, AcquisitionParameters, ReconstructionParameters
from unet import UNet
from detection import generate_ccp_detections
from linking import LinkingGraph, UntanglingGraph, print_solver_status
from jupyter_utils import *

from ipywidgets import widgets, VBox, Output

from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, clear_output

from IPython.display import Javascript
from google.colab import widgets as colab_widgets

from google.colab import files
import time

def animate_50_frames(lr_images):
    """
    Render at most the first 50 frames of ``lr_images`` as a grayscale
    animation and return it wrapped in an ``IPython.display.HTML`` object.
    """
    preview = lr_images[:50]

    figure, axes = plt.subplots()
    frame_artist = axes.imshow(preview[0], cmap='gray', animated=True)

    def _show_frame(index):
        # Swap the pixel data of the single image artist; blitting redraws
        # only the artists returned here.
        frame_artist.set_array(preview[index])
        return [frame_artist]

    movie = animation.FuncAnimation(
        figure, _show_frame, frames=preview.shape[0], interval=100, blit=True
    )
    # Close the figure so the notebook does not also render a static copy.
    plt.close(figure)
    player_markup = movie.to_jshtml()
    return HTML(player_markup)

# Generic "please wait" banner with a rotating hourglass. The wrapper <div>
# uses the same id ("loading-msg") as the other loading banners in this cell.
# Fixed: the opening <div> tag previously had a stray second quote after the id.
html_code_wait = """
    <div id="loading-msg">
      <br /><br />
      <b>
        <span style="
          display: inline-block;
          animation: flipPause 2s ease infinite;
        ">⏳</span>
        Please wait...
      </b>
    </div>

    <style>
    @keyframes flipPause {
      0%   { transform: rotate(0deg); }
      40%  { transform: rotate(180deg); }
      50%  { transform: rotate(180deg); }
      90% { transform: rotate(360deg); }
      100% { transform: rotate(360deg); }
    }
    </style>
    """

# Banner shown while an animation is being rendered. Later cells look the
# element up with document.getElementById("loading-msg") to replace or remove
# it, so the id attribute must be well-formed.
# Fixed: the opening <div> tag previously had a stray second quote after the id.
html_code_clock = """
    <div id="loading-msg">
      <br /><br />
      <b>
        <span style="
          display: inline-block;
          animation: flipPause 2s ease infinite;
        ">⏳</span>
        Loading the animation, please wait...
      </b>
    </div>

    <style>
    @keyframes flipPause {
      0%   { transform: rotate(0deg); }
      40%  { transform: rotate(180deg); }
      50%  { transform: rotate(180deg); }
      90% { transform: rotate(360deg); }
      100% { transform: rotate(360deg); }
    }
    </style>
    """

# The four banners below share all of their markup and differ only in the
# message line, so they are assembled from a common head and tail.
_BANNER_HEAD = """
<div id="loading-msg">
  <br /><br />
  <b>
    <span style="
      display: inline-block;
      animation: flipPause 2s ease infinite;
    ">⏳</span>
"""

_BANNER_TAIL = """  </b>
</div>

<style>
@keyframes flipPause {
  0%   { transform: rotate(0deg); }
  40%  { transform: rotate(180deg); }
  50%  { transform: rotate(180deg); }
  90%  { transform: rotate(360deg); }
  100% { transform: rotate(360deg); }
}
</style>
"""


def _loading_banner(message):
    """Return the loading-banner HTML with ``message`` on its text line."""
    return f"{_BANNER_HEAD}    {message}\n{_BANNER_TAIL}"


html_code_reconstruction = _loading_banner("Preparing data for reconstruction, please wait...")

html_code_detection = _loading_banner("Preparing data for detection, please wait...")

html_code_linking = _loading_banner("Preparing data for linking, please wait...")

html_code_linking2 = _loading_banner("Linking, please wait...")

print("✅ You are all set!")

# Endocytosis: single-channel movie

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below.

In [None]:
#@title 🗂️ Load Input Data
import os, zipfile, json
import numpy as np
import pandas as pd
from IPython.display import display
import ipywidgets as widgets
from ipyfilechooser import FileChooser
from google.colab import drive

# Global memory variables
# These four names are (re)assigned by finalize() and read by later cells.
filename = None          # .dv file path
mask_filename = None     # mask .tif file path
parameters = {}          # loaded JSON dict
param_file_used = None   # JSON used path

# Relative path: resolved against the current working directory.
DEFAULT_PARAM_FILE = "data/Oxford/20240703/default_parameters.json"

# 1️⃣ Data Mode Radio Buttons
# Each option is a (label, value) pair; the value is what on_mode_change and
# finalize branch on.
data_mode = widgets.RadioButtons(
    options=[('Use Test Data','test'),('Use Own Data (Drive)','drive'),('Precomputed Data (Zip)','zip')],
    value='test', description='Data Mode:'
)

# 2️⃣ Test Data Dropdown
test_data_dropdown = widgets.Dropdown(
    options=[
        'data/Oxford/20240703/SHSY5Y_RUSHLAMP_CLCSNAP_107_subset.dv'
    ],
    value='data/Oxford/20240703/SHSY5Y_RUSHLAMP_CLCSNAP_107_subset.dv',
    description='Test Data:'
)
test_box = widgets.VBox([test_data_dropdown])

# 3️⃣ File Choosers (hidden initially)
# layout.display = 'none' hides a chooser; on_mode_change resets it to None
# to show the ones that belong to the selected mode.
drive_chooser_dv = FileChooser('.', title="Select .dv file", show_hidden=False)
drive_chooser_dv.filter_pattern = ['*.dv']
drive_chooser_dv.layout.display = 'none'

drive_chooser_mask = FileChooser('.', title="Select mask .tif file", show_hidden=False)
drive_chooser_mask.filter_pattern = ['*.tif']
drive_chooser_mask.layout.display = 'none'

zip_chooser = FileChooser('.', title="Select data .zip file", show_hidden=False)
zip_chooser.filter_pattern = ['*.zip']
zip_chooser.layout.display = 'none'

# 4️⃣ Layout Containers & Outputs
input_container = widgets.VBox()   # holds the input widgets of the active mode
status_out = widgets.Output()      # progress messages (Drive mounting, zip extraction)
final_out = widgets.Output()       # result or warning printed by finalize()



# 5️⃣ Handle Data Mode Changes
def on_mode_change(change):
    """React to a new data-mode selection.

    Clears both output areas, hides every input widget, mounts Google Drive
    when the new mode needs it, then shows only the inputs of the new mode.
    """
    if change['name'] != 'value':
        return
    selected = change['new']

    status_out.clear_output()
    final_out.clear_output()

    # Start from a blank slate: hide every mode-specific input.
    for control in (test_box, drive_chooser_dv, drive_chooser_mask, zip_chooser):
        control.layout.display = 'none'

    # Both Drive-backed modes browse files under the mounted Drive.
    if selected in ('drive', 'zip'):
        if not os.path.isdir('/content/drive/MyDrive'):
            with status_out:
                print("🔄 Mounting Google Drive...")
            drive.mount('/content/drive')
            with status_out:
                print("✅ Google Drive mounted.")
        base = '/content/drive/MyDrive'
        for chooser in (drive_chooser_dv, drive_chooser_mask, zip_chooser):
            chooser.reset(base)

    # Pick the widgets that belong to the selected mode and reveal them.
    if selected == 'test':
        visible = [test_box]
    elif selected == 'drive':
        visible = [drive_chooser_dv, drive_chooser_mask]
    else:
        visible = [zip_chooser]
    for control in visible:
        control.layout.display = None
    input_container.children = visible


data_mode.observe(on_mode_change, names='value')
# The radio buttons start on 'test', so show the test-data dropdown initially.
input_container.children = [test_box]

# 6️⃣ Finalize & Validate
def _chosen_file(chooser):
    """Return the path picked in a FileChooser, or None when nothing is selected."""
    base = chooser.selected_path
    if base is None:
        return None
    if os.path.isdir(base):
        name = chooser.selected_filename
        return os.path.join(base, name) if name else base
    return base


def _read_parameters(path):
    """Load a JSON parameter file and close it; return {} when ``path`` is falsy."""
    if not path:
        return {}
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def finalize(_):
    """Validate the current selection and publish it to the notebook globals.

    Used as the click handler of the confirm button (the argument is ignored)
    and also called directly. Depending on ``data_mode.value`` it resolves the
    .dv file, the mask file and the parameter JSON, then sets the globals
    ``filename``, ``mask_filename``, ``parameters`` and ``param_file_used``.
    Messages are printed into ``final_out``.

    The globals are assigned only after all checks of the chosen mode have
    passed. ``parameters`` is reset to ``{}`` when no parameter file is
    found, so values from an earlier selection are never carried over.
    """
    global filename, mask_filename, parameters, param_file_used
    with final_out:
        final_out.clear_output()
        mode = data_mode.value
        default_params = DEFAULT_PARAM_FILE if os.path.isfile(DEFAULT_PARAM_FILE) else None

        # TEST
        if mode == 'test':
            dv = test_data_dropdown.value
            if not os.path.isfile(dv):
                print(f"⚠️ Test .dv not found: {dv}")
                return
            loaded = _read_parameters(default_params)
            filename = dv
            mask_filename = dv.replace('subset.dv', 'MASK.tif')
            param_file_used = default_params
            parameters = loaded
            print("✅ Loaded Test Data")
            print(f".dv: {filename}")
            print(f"Mask: {mask_filename}")
            print(f"Params: {param_file_used}")
            return

        # DRIVE
        if mode == 'drive':
            dv = _chosen_file(drive_chooser_dv)
            if dv is None:
                print("⚠️ Please select a valid .dv file and click 'Select'.")
                return
            if not os.path.isfile(dv):
                print(f"⚠️ .dv file not found: {dv}")
                return

            mask = _chosen_file(drive_chooser_mask)
            if mask is None:
                print("⚠️ Please select a valid mask .tif file and click 'Select'.")
                return
            if not os.path.isfile(mask):
                print(f"⚠️ Mask file not found: {mask}")
                return

            loaded = _read_parameters(default_params)
            filename = dv
            mask_filename = mask
            param_file_used = default_params
            parameters = loaded

            print("✅ Loaded Drive Data")
            print(f".dv: {filename}")
            print(f"Mask: {mask_filename}")
            print(f"Params: {param_file_used}")
            return

        # ZIP
        if mode == 'zip':
            zp = _chosen_file(zip_chooser)
            if zp is None:
                print("⚠️ Please select a valid .zip file and click 'Select'.")
                return
            if not os.path.isfile(zp):
                print(f"⚠️ .zip not found: {zp}")
                return

            with status_out:
                status_out.clear_output()
                print(f"⏳ Extracting {zp}...")

            extract_dir = 'precomputed_data'
            try:
                with zipfile.ZipFile(zp, 'r') as zf:
                    zf.extractall(extract_dir)
            except zipfile.BadZipFile:
                print(f"⚠️ Not a valid .zip archive: {zp}")
                return

            pth = os.path.join(extract_dir, 'parameters.json')
            param_file_used = pth if os.path.isfile(pth) else None
            parameters = _read_parameters(param_file_used)

            filename = os.path.join(extract_dir, 'input.dv')
            mask_filename = os.path.join(extract_dir, 'mask_file.tif')

            print("✅ Loaded Precomputed Data")
            print(f"Params: {param_file_used}")
            print(f".dv: {filename}")
            print(f"Mask: {mask_filename}")



# 7️⃣ Display UI
def run_ui():
    """Show the data-mode selector, its inputs, the confirm button and both output areas."""
    print("Select Data Mode below:")
    display(data_mode)
    display(input_container)

    confirm_button = widgets.Button(
        description="Confirm Your Choices",
        button_style='success',
    )
    confirm_button.on_click(finalize)
    display(confirm_button, status_out, final_out)


run_ui()

# With the default 'test' mode no user input is needed, so load the test data
# right away, exactly as if the confirm button had been clicked.
if data_mode.value == 'test':
    finalize(None)


Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below (it should turn into <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/stop1.png" height="25px">)

In [None]:
#@title ↳ Show input

# Preview the loaded input movie (first 50 frames) as an inline animation.
if 'filename' not in globals() or not filename:
    print("⚠️ No data. Please load input DV file above!")
else:
    # Show the "loading" banner before the slow steps (reading the file and
    # rendering the animation) so the user gets feedback immediately.
    display(HTML(html_code_clock))

    # Load the image data from the file
    lr_images = open_image_file(filename).astype(np.float64)

    # Generate animation HTML for the first 50 frames
    anim_html = animate_50_frames(lr_images)

    # JavaScript snippet to *replace* the "loading" text with "Only first 50 frames are displayed."
    replace_loading_js = """
    <script>
    setTimeout(function(){
        var loadingDiv = document.getElementById("loading-msg");
        if (loadingDiv) {
            loadingDiv.innerHTML = '<br /><b>Only the first 50 frames are displayed.</b>';
        }
    }, 0);
    </script>
    """

    # Display the animation and then replace the "Loading..." message
    display(anim_html)
    display(HTML(replace_loading_js))


❗❗❗ Press <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below whenever you change the input data.

In [None]:
#@title ⚙️ Setup Parameters

# Choose the torch device name once; later cells build torch.device from it.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device_type}")

# Check if 'filename' and 'parameters' exist in the current session
if 'filename' not in globals() or not filename:
    # If there's no filename, display a message and stop
    print("⚠️ No data. Please load input DV file above!")
else:
    # We have a filename, so let's use 'parameters' if it exists or default to {}
    if 'parameters' not in globals():
        parameters = {}
        print("⚠️ 'parameters' not found. Using empty defaults.")
    else:
        print(f"Data file to be processed: {filename}\n\n")

    # Extract sub-dictionaries with fallback defaults
    # A missing section yields {}, so each widget below falls back to its
    # hard-coded default value.
    otf_params       = parameters.get("otf_parameters", {})
    recon_params     = parameters.get("reconstruction_parameters", {})
    input_params     = parameters.get("input_parameters", {})
    detector_params  = parameters.get("detector_parameters", {})
    linking_params   = parameters.get("linking_parameters", {})

    # ------------------------------------------
    # Update function for all widgets
    # ------------------------------------------
    def update_values(_=None):
        """Mirror the current value of every parameter widget into its global variable.

        Serves as the ``observe`` callback of the widgets (the change event is
        ignored) and is also called once directly to seed the variables.
        """
        # Global variable name -> widget whose ``value`` it mirrors.
        bindings = {
            # OTF parameters
            "image_size": image_size_widget,
            "na": na_widget,
            "px_size": px_size_widget,
            "wavelength": wavelength_widget,
            "otf_curvature": otf_curvature_widget,
            # Reconstruction parameters
            "wiener_parameter": wiener_parameter_widget,
            "apodization_cutoff": apodization_cutoff_widget,
            "apodization_bend": apodization_bend_widget,
            # Input parameters (excluding filename)
            "background_threshold": background_threshold_widget,
            "border_fade": border_fade_widget,
            "deconvolution_iterations": deconvolution_iterations_widget,
            "fit_exclude": fit_exclude_widget,
            "batch_size": batch_size_widget,
            # Detector parameters
            "checkpoint_path": checkpoint_path_widget,
            # Linking parameters
            "max_linking_distance": max_linking_distance_widget,
            "birth_death_cost": birth_death_cost_widget,
            "cls_cost_multiplier": cls_cost_multiplier_widget,
            "gamma": gamma_widget,
        }
        module_vars = globals()
        for target, source in bindings.items():
            module_vars[target] = source.value

    def attach_listeners(widget_list):
        """Subscribe ``update_values`` to value changes of every widget in ``widget_list``."""
        for control in widget_list:
            control.observe(update_values, names="value")

    # ------------------------------------------
    # Define your widgets
    # ------------------------------------------
    label_width = '150px'
    input_width = '300px'

    def _field(widget_cls, source, key, fallback, caption, width=input_width):
        """Create one labelled input whose initial value is ``source[key]`` or ``fallback``."""
        return widget_cls(
            value=source.get(key, fallback),
            description=caption,
            style={'description_width': label_width},
            layout={'width': width},
        )

    # OTF widgets
    image_size_widget = _field(widgets.IntText, otf_params, "image_size", 512, "Image Size:")
    na_widget = _field(widgets.FloatText, otf_params, "na", 1.5, "NA:")
    px_size_widget = _field(widgets.FloatText, otf_params, "px_size", 0.0791, "Pixel Size (µm):")
    wavelength_widget = _field(widgets.IntText, otf_params, "wavelength", 603, "Wavelength (nm):")
    otf_curvature_widget = _field(widgets.FloatText, otf_params, "otf_curvature", 0.3, "OTF Curvature:")

    otf_fields = [
        image_size_widget, na_widget, px_size_widget, wavelength_widget, otf_curvature_widget
    ]
    otf_ui = VBox(otf_fields)
    attach_listeners(otf_fields)

    # Reconstruction widgets
    wiener_parameter_widget = _field(widgets.FloatText, recon_params, "wiener_parameter", 0.05, "Wiener Param:")
    apodization_cutoff_widget = _field(widgets.FloatText, recon_params, "apodization_cutoff", 2.0, "Apod. Cutoff:")
    apodization_bend_widget = _field(widgets.FloatText, recon_params, "apodization_bend", 0.9, "Apod. Bend:")

    reconstruction_fields = [
        wiener_parameter_widget, apodization_cutoff_widget, apodization_bend_widget
    ]
    reconstruction_ui = VBox(reconstruction_fields)
    attach_listeners(reconstruction_fields)

    # Input widgets (excluding filename)
    background_threshold_widget = _field(widgets.IntText, input_params, "background_threshold", 100, "Background Thr:")
    border_fade_widget = _field(widgets.IntText, input_params, "border_fade", 15, "Border Fade:")
    deconvolution_iterations_widget = _field(widgets.IntText, input_params, "deconvolution_iterations", 10, "Deconv. Iters:")
    fit_exclude_widget = _field(widgets.FloatText, input_params, "fit_exclude", 0.75, "Fit Exclude:")
    batch_size_widget = _field(widgets.IntText, input_params, "batch_size", 2, "Batch Size:")

    input_fields = [
        background_threshold_widget, border_fade_widget, deconvolution_iterations_widget,
        fit_exclude_widget, batch_size_widget
    ]
    input_ui = VBox(input_fields)
    attach_listeners(input_fields)

    # Detector widget (wider than the others so the whole path is visible)
    checkpoint_path_widget = _field(
        widgets.Text,
        detector_params,
        "checkpoint_path",
        "data/Interim/checkpoints/ccp-detector-sandy-wildflower-269.pt",
        "Checkpoint Path:",
        width='700px',
    )
    detector_fields = [checkpoint_path_widget]
    detector_ui = VBox(detector_fields)
    attach_listeners(detector_fields)

    # Linking widgets
    max_linking_distance_widget = _field(widgets.FloatText, linking_params, "max_linking_distance", 7.5, "Max Link Dist:")
    birth_death_cost_widget = _field(widgets.FloatText, linking_params, "birth_death_cost", 5, "Birth/Death Cost:")
    cls_cost_multiplier_widget = _field(widgets.FloatText, linking_params, "cls_cost_multiplier", 1, "Cls Multiplier:")
    gamma_widget = _field(widgets.FloatText, linking_params, "gamma", 10, "Gamma:")

    linking_fields = [
        max_linking_distance_widget, birth_death_cost_widget, cls_cost_multiplier_widget, gamma_widget
    ]
    linking_ui = VBox(linking_fields)
    attach_listeners(linking_fields)

    # ------------------------------------------
    # Display in Colab as tabbed widgets
    # ------------------------------------------
    def display_colab_widgets():
        """Render the five parameter groups, one per tab of a Colab TabBar."""
        panels = [
            ("OTF Parameters", otf_ui),
            ("Reconstruction Parameters", reconstruction_ui),
            ("Input Parameters", input_ui),
            ("Detector Parameters", detector_ui),
            ("Linking Parameters", linking_ui),
        ]
        tab = colab_widgets.TabBar([title for title, _panel in panels], location="top")

        for position, (_title, panel) in enumerate(panels):
            with tab.output_to(position):
                display(panel)

    # Seed the global parameter variables from the widgets' initial values,
    # then show the UI.
    update_values()
    display_colab_widgets()

    print("\n⚙️ Setup loaded! You can adjust the parameters above.")
    print("📝 Any changes you make are saved instantly.")
    print("🔁 Re-running the cell will reset all parameters to default.")


## 🧩 Reconstruction

<img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/rec_head.svg" height="300">

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below.

In [None]:
#@title Run Reconstruction

# Reconstruction cell: estimates the illumination parameters on the first
# 9 raw images, then reconstructs the whole movie on `device_type`.
# Relies on the globals set by the "Setup Parameters" cell (na, wavelength, ...)
# and on `filename` from the "Load Input Data" cell.
display(HTML(html_code_reconstruction))
time.sleep(0.1)

otf = OTF(na, wavelength, px_size, image_size, otf_curvature)
# `config` is consumed by the single-frame run_reconstruction call below;
# `ap` / `rp` carry the same settings for the Reconstruction object.
config = dict(na=na, wavelength=wavelength, px_size=px_size, wiener_parameter=wiener_parameter, apo_cutoff=apodization_cutoff, apo_bend=apodization_bend)
ap = AcquisitionParameters(na=na, wavelength=wavelength, px_size=px_size, image_size=image_size)
rp = ReconstructionParameters(wiener_parameter=wiener_parameter, apodization_cutoff=apodization_cutoff, apodization_bend=apodization_bend)

lr_images = open_image_file(filename).astype(np.float64)

# Common preprocessing for all images in the movie
lr_images = subtract_background(lr_images, background_threshold)
# NOTE(review): the return value is discarded, so fade_border presumably
# modifies lr_images in place - confirm against preprocessing.fade_border.
fade_border(lr_images, border_fade)

# Estimate parameters on the first frame (9 images)
rl_images = deconvolve_richardson_lucy(lr_images[:9], deconvolution_iterations, otf())
f_images = np.fft.fft2(rl_images)

components = separate_components(f_images)

# This step may be skipped if we already have approximate shifts from a previous reconstruction
minimum_distance = compute_minimum_distance(fit_exclude, image_size, px_size, na, wavelength)
# The lists below have one entry per group of three consecutive images (i = 0, 1, 2).
approximate_shifts = [estimate_integer_shift(components[i * 3], components[i * 3 + 1], minimum_distance)
                      for i in range(3)]

shifts = [optimizer_shift_v2(approximate_shifts[i], components[i * 3], components[i * 3 + 1])
          for i in range(3)]

phase_offsets = [estimate_phase_offset(f_images[i * 3:(i + 1) * 3], shifts[i])
                 for i in range(3)]

# This step is generally not possible to do accurately so we just hardcode the values but the estimated values serve as a quality metric of the reconstruction
estimated_modulations = [estimate_phase_modulation(components[i * 3:(i + 1) * 3], shifts[i], otf)
                         for i in range(3)]

modulations = [1.0 for _ in range(3)]

# Single frame cpu reconstruction
# (fft_result / spatial_result are not used again in the cells visible here.)
fft_result, spatial_result = run_reconstruction(np.fft.fft2(lr_images[:9]), otf, shifts, phase_offsets, modulations, config)

# Whole movie gpu reconstruction
# `device` is reused by the detection cell.
device = torch.device(device_type)

# Remove the "preparing" banner before the long-running step starts.
clear_output(wait=True)

# Note: this rebinds the global name `r`, which the install cell used for an
# HTTP response.
r = Reconstruction(otf, shifts, phase_offsets, modulations, ap, rp, device)
# `result` is the reconstructed movie read by the preview, detection and
# linking cells.
result = r.reconstruct(lr_images, batch_size)

print("✅ Reconstruction finished!")

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below (it should turn into <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/stop1.png" height="25px">)

In [None]:
#@title ↳ Show Reconstruction

# Preview cell: plays the reconstructed movie (`result`) as an inline animation.
display(HTML(html_code_clock))

fig, ax = plt.subplots()
im = ax.imshow(result[0], cmap='gray')

def animate(i):
    """Show frame ``i`` of ``result`` and return the updated artist for blitting."""
    im.set_array(result[i])
    return [im]

ani = animation.FuncAnimation(fig, animate, frames=result.shape[0], interval=100, blit=True)
plt.close(fig)  # Prevent static image display

# to_jshtml renders every frame up front, which is the slow part of this cell.
anim_html = HTML(ani.to_jshtml())

display(anim_html)

# The banner shown at the top of the cell has id "loading-msg"; this script
# removes that element once it runs in the browser.
remove_loading_js = """
<script>
    // Remove the loading message after the animation HTML is inserted
    setTimeout(function(){
        var loadingDiv = document.getElementById("loading-msg");
        if (loadingDiv) {
            loadingDiv.remove();
        }
    }, 0);
</script>
"""

display(HTML(remove_loading_js))

## 🔍 Detection

<img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/det_head2.svg" height="190">

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below.

In [None]:
#@title Run Detection
# Detection cell: loads the UNet checkpoint and runs it over the reconstructed
# movie (`result`). Uses `device` created in the reconstruction cell and
# `checkpoint_path` from the "Setup Parameters" cell.
display(HTML(html_code_detection))
time.sleep(0.1)
model = UNet(depth=3, start_filters=16, up_mode='nearest')
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only point checkpoint_path at files from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device_type),weights_only=False)
# The loaded object is used directly as the model's state dict.
model.load_state_dict(checkpoint)
model.eval()
model.to(device);
# Remove the "preparing" banner before detection starts.
clear_output(wait=True)
# `detections` is read by the preview and linking cells.
detections = generate_ccp_detections(model, device, result)
print("✅ Detection finished!")

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below (it should turn into <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/stop1.png" height="25px">)

In [None]:
#@title ↳ Show Detection Positions

# Preview cell: plays the reconstructed movie with the detected positions of
# each frame overlaid as red "+" markers.
display(HTML(html_code_clock))

fig, ax = plt.subplots()
im = ax.imshow(result[0], cmap='gray')
d0 = detections[detections.frame == 0]
scat = ax.scatter(d0.x, d0.y, s=5, marker='+', c='red')

def animate(i):
    """Show frame ``i`` with its detections and return the updated artists for blitting."""
    im.set_array(result[i])
    di = detections[detections.frame == i]
    # Always hand set_offsets an (N, 2) array. column_stack yields shape
    # (0, 2) for a frame without detections, whereas a bare [] is 1-D and can
    # make set_offsets fail when it indexes the x/y columns.
    scat.set_offsets(np.column_stack((di.x.values, di.y.values)))
    return [im, scat]

ani = animation.FuncAnimation(fig, animate, frames=result.shape[0], interval=100, blit=True)
plt.close(fig)  # avoid an extra static copy of the figure
anim_html = HTML(ani.to_jshtml())
display(anim_html)

# The banner shown at the top of the cell has id "loading-msg"; this script
# removes that element once it runs in the browser.
remove_loading_js = """
<script>
    setTimeout(function(){
        var loadingDiv = document.getElementById("loading-msg");
        if (loadingDiv) {
            loadingDiv.remove();
        }
    }, 0);
</script>
"""

display(HTML(remove_loading_js))

## 🔗 Linking

<img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/lin_head.png" height="230">

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below

In [None]:
#@title Run Linking

def distance_function(a: pd.DataFrame, b: pd.DataFrame):
    """Return the pairwise linking cost between the rows of `a` and the rows of `b`.

    The cost of pairing row i of `a` with row j of `b` is their Euclidean
    distance in the (x, y) plane plus `cls_cost_multiplier` times the squared
    difference of their `cls` values. The result has one row per row of `a`
    and one column per row of `b`.
    """
    positional_cost = spatial.distance.cdist(
        np.array(a[['x', 'y']]),
        np.array(b[['x', 'y']]),
        metric='euclidean',
    )
    # Outer difference: entry (i, j) is a.cls[i] - b.cls[j].
    cls_gap = np.subtract.outer(np.array(a['cls']), np.array(b['cls']))
    return positional_cost + cls_cost_multiplier * np.square(cls_gap)

def birth_death_cost_function(row: pd.Series):
    """Return the birth/death cost for one detection: `birth_death_cost` plus its `cls` value."""
    base_cost = birth_death_cost
    return base_cost + row['cls']

# --- Stage 1: link detections into tracklets --------------------------------
# Show the banner held in `html_code_linking` (defined elsewhere in the notebook).
display(HTML(html_code_linking))
# Short pause — presumably so the banner is rendered before the blocking
# graph construction starts (TODO confirm).
time.sleep(0.1)

# Build the linking problem from the detections, using the two cost callbacks
# defined in this cell. `max_linking_distance` comes from the parameter setup.
linking_graph = LinkingGraph(detections, distance_function, max_linking_distance, birth_death_cost_function)
# wait=True defers clearing until the next output arrives, so the banner is
# replaced by the problem-size line rather than flickering away.
clear_output(wait=True)
print(f'{linking_graph.solver.NumConstraints()} constraints, {linking_graph.solver.NumVariables()} variables')
display(HTML(html_code_linking2))
time.sleep(0.1)
status = linking_graph.solve()
print_solver_status(status, linking_graph.solver)

tracklets = linking_graph.get_result()

# --- Stage 2: untangle tracklets into trajectories --------------------------
# `gamma` comes from the parameter setup; its meaning is defined by UntanglingGraph.
untangling_graph = UntanglingGraph(tracklets, gamma)
clear_output(wait=True)
print(f'{untangling_graph.solver.NumConstraints()} constraints, {untangling_graph.solver.NumVariables()} variables')
display(HTML(html_code_linking2))
time.sleep(0.1)
status = untangling_graph.solve()
print_solver_status(status, untangling_graph.solver)

trajectories = untangling_graph.get_result(detections)

# Drop very short tracks: keep only particles whose `frame` column has at
# least 6 non-null entries.
filtered_trajectories = trajectories.groupby('particle').filter(lambda t: t.frame.count() >= 6)
# Replace the solver log with the final status line below.
clear_output(wait=True)

print("✅ Linking finished!")

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below (it should turn into <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/stop1.png" height="25px">)

In [None]:
#@title ↳ Show Linking

# `html_code_clock` is defined elsewhere in the notebook; presumably it is a
# loading indicator with id "loading-msg", which the script at the end of this
# cell removes once the animation has been inserted.
display(HTML(html_code_clock))

fig, ax = plt.subplots()
im = ax.imshow(result[0], cmap='gray')

# Configure tail length (how many past frames to show)
tail_length = 10

# Unique particles, each mapped to a colour sampled from the tab20 colormap
particles = filtered_trajectories['particle'].unique()
colors = plt.cm.tab20(np.linspace(0, 1, len(particles)))
particle_color_map = dict(zip(particles, colors))

# Initialize line artists for each particle
line_artists = {pid: ax.plot([], [], color=particle_color_map[pid], lw=1)[0] for pid in particles}
dot = ax.scatter([], [], s=5, marker='o', c='red')

def animate(i):
    """Draw frame `i` with current positions and the recent trail of every track.

    Positions recorded in frame `i` are shown as dots. For each particle, the
    points from frames ``i - tail_length`` to ``i`` are joined into a line; a
    particle with fewer than two points in that window gets its line hidden.

    Returns the list of all artists, as required by ``blit=True``.
    """
    im.set_array(result[i])

    current = filtered_trajectories[
        (filtered_trajectories['frame'] >= i - tail_length) &
        (filtered_trajectories['frame'] <= i)
    ]

    # Update dot for current positions
    now = current[current['frame'] == i]
    dot.set_offsets(np.c_[now.x.values, now.y.values])

    # Split the window by particle once, instead of building a boolean mask
    # over `current` for every particle on every frame.
    trails = dict(iter(current.groupby('particle')))

    # Update each trajectory's line
    for pid, line in line_artists.items():
        trail = trails.get(pid)
        if trail is not None and len(trail) >= 2:
            trail = trail.sort_values('frame')
            line.set_data(trail.x.values, trail.y.values)
            line.set_alpha(1.0)  # current full line
        else:
            # Not (or barely) present in the window: hide the line.
            line.set_data([], [])
            line.set_alpha(0.0)

    return [im, dot] + list(line_artists.values())

# One animation frame per entry along the first axis of `result`, 100 ms apart.
ani = animation.FuncAnimation(fig, animate, frames=result.shape[0], interval=100, blit=True)
plt.close(fig)  # avoid an additional static rendering of the figure
display(HTML(ani.to_jshtml()))

# Remove loading text
remove_loading_js = """
<script>
    setTimeout(function(){
        var loadingDiv = document.getElementById("loading-msg");
        if (loadingDiv) {
            loadingDiv.remove();
        }
    }, 0);
</script>
"""
display(HTML(remove_loading_js))

## 📈 Analysis



> *“All this effort is only worthwhile if it leads to new biological discoveries. Discoveries that could not have been made with any other tool.”*
>
> — Lothar, Royal Oak, circa July 2024




Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below.

In [None]:
#@title Productivity  ⚠️


In [None]:
#@title Relative frequency vs Lifetime

# Seconds per frame. Despite its name, this value is multiplied by a number of
# frames to obtain a lifetime in seconds, i.e. it is a frame interval.
frame_rate = 3

n_frames = len(result)
mask = open_image_file(mask_filename).astype(np.float64)

def update_plot(frame_rate):
    """Plot the weighted lifetime distribution of complete trajectories.

    A trajectory from `filtered_trajectories` is counted only if all of the
    following hold:

    * at least one of its points has ``cls > 0.7``;
    * it touches neither the first nor the last frame of `result`;
    * every point lies more than 15 px away from the image border;
    * every point falls on a non-zero pixel of `mask`.

    Lifetimes (in seconds) are histogrammed into 6 s bins starting at 10 s and
    drawn as a normalised curve.

    Args:
        frame_rate: Seconds per frame; multiplied by the lifetime in frames.
    """
    border = 15  # px; tracks that come this close to the image edge are dropped
    # Take the image size from the stack itself rather than assuming 1024 x 1024.
    n_rows, n_cols = result.shape[1:3]

    def is_complete(t):
        # The mask lookup comes last on purpose: by then every coordinate is
        # known to lie inside the image, so the index is valid.
        return (
            (t.cls > 0.7).any() and
            (0 < t.frame).all() and
            (t.frame < n_frames - 1).all() and
            (t.x > border).all() and
            (t.x < n_cols - border).all() and
            (t.y > border).all() and
            (t.y < n_rows - border).all() and
            mask[t.y.astype(int), t.x.astype(int)].all()
        )

    lifetime_frames = (
        filtered_trajectories.groupby('particle')
        .filter(is_complete)
        .groupby('particle').frame.agg(lambda f: f.max() - f.min() + 1).values
    )

    lifetimes = lifetime_frames * frame_rate
    # A track of L frames fits into the n_frames - 2 admissible frames at
    # n_frames - 1 - L start positions; the weight is the inverse of that
    # fraction — presumably a correction for long tracks being more likely to
    # be cut off by the start or end of the recording.
    lifetime_weights = (n_frames - 2) / (n_frames - 1 - lifetime_frames)

    seconds_per_bin = 6
    bin_edges = np.arange(10, 201, seconds_per_bin)

    # With nothing inside the bins, density=True would divide 0 by 0 and
    # produce an all-NaN (empty) plot; say so explicitly instead.
    in_range = (lifetimes >= bin_edges[0]) & (lifetimes <= bin_edges[-1])
    if not np.any(in_range):
        print(f"⚠️ No trajectories with a lifetime between {bin_edges[0]} and "
              f"{bin_edges[-1]} s passed the filters; nothing to plot.")
        return

    hist, _ = np.histogram(lifetimes, weights=lifetime_weights, bins=bin_edges, density=True)

    # Plot the histogram as a curve through the bin centres
    plt.figure(figsize=(8, 6))
    plt.plot((bin_edges[1:] + bin_edges[:-1]) / 2, hist)
    plt.xlabel('Lifetime [s]')
    plt.ylabel('Relative frequency')
    plt.xticks(np.arange(0, 201, 30))
    plt.xticks(bin_edges, minor=True)  # minor ticks mark the bin edges
    plt.tick_params(axis='x', which='minor', direction='in')
    plt.show()

# Run the plotting function once
update_plot(frame_rate)

## 📤 Export Results

You can save important outputs (and inputs) to your Google Drive or computer. Select which outputs to include, then click **Prepare ZIP.** You can change the randomly generated name of the ZIP file. After that, you can either download the ZIP file or save it to your Google Drive. You can load the files back here later.

Continue by pressing <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> below

In [None]:
#@title 📦 Create Results ZIP


# Word pools for the default ZIP name (order matters for seeded reproducibility).
adjectives = [
    'silent', 'purple', 'ancient', 'rapid',
    'curious', 'brave', 'happy', 'sleepy',
    'bright', 'dark', 'fuzzy', 'eager',
    'gentle', 'wild', 'quiet', 'fiery',
]

nouns = [
    'otter', 'moon', 'mountain', 'river',
    'cloud', 'panther', 'eagle', 'forest',
    'breeze', 'ember', 'wave', 'meadow',
    'shadow', 'falcon', 'comet', 'storm',
]

def generate_random_name():
    """Return a random name of the form ``<adjective>_<noun>_<3-digit number>``."""
    # Tuple elements are evaluated left to right, so the random draws happen
    # in the order adjective, noun, number.
    pieces = (
        random.choice(adjectives),
        random.choice(nouns),
        str(random.randint(100, 999)),
    )
    return "_".join(pieces)

# Text written verbatim to README.txt, which on_prepare adds to every export ZIP.
readme_text = """SHAPE Export Package
====================

This ZIP file contains the exported results from your SHAPE analysis session.

Contents may include:
- parameters.json ............. Reconstruction and detection parameters
- input.dv .................... The input DV file used in the analysis
- mask_file.tif ............... The mask file used in the analysis
- reconstruction.tif ......... Time-lapse TIF image of the reconstructed sample
- detections.csv ............. Detected object positions (per frame)
- trajectories.csv ........... Tracked particle trajectories across time
- napari_show_detections.py .. Python script to view detections in napari
- napari_show_trajectories.py. Python script to view trajectories in napari

------------------------------------------------------------
📊 Viewing the Results in Napari (Python-based viewer)
------------------------------------------------------------

1. Make sure you have napari and required libraries installed:

   pip install napari[all] pandas scikit-image

2. To view **detections**, run the following in terminal or Python:

   python napari_show_detections.py

   → This will show the reconstructed image and overlay detections per time frame.

3. To view **trajectories**, run:

   python napari_show_trajectories.py

   → This will show tracked particles with trail lines over time.

Make sure the corresponding reconstruction.tif and result .csv files
are in the same directory as the Python scripts when running them.

------------------------------------------------------------
💬 Notes
------------------------------------------------------------
- Napari works best in a desktop environment.
- These scripts use matplotlib, pandas, and skimage.
- If you move the files elsewhere, keep related files together.

Have fun exploring your data!
"""

# Input file name without directory and extension.
# NOTE(review): `base` is not used anywhere in the visible part of this cell —
# confirm whether another cell reads it before removing it.
base      = os.path.splitext(os.path.basename(filename))[0]
# 1️⃣ Build the file-selection UI

# Pre-fill the name field with a randomly generated, editable ZIP name.
default_zip_name = f"{generate_random_name()}.zip"
zipname_field = widgets.Text(value=default_zip_name, description='ZIP Name:', layout={'width': '700px'})
# One checkbox per exportable artifact; all are ticked by default.
input_cb    = widgets.Checkbox(value=True, description="Input Data (DV)")
mask_cb     = widgets.Checkbox(value=True, description="Mask File (TIFF)")
param_cb    = widgets.Checkbox(value=True, description="Parameters (JSON)")
recon_cb    = widgets.Checkbox(value=True, description="Reconstruction (TIFF)")
detect_cb   = widgets.Checkbox(value=True, description="Detections (CSV)")
tracks_cb   = widgets.Checkbox(value=True, description="Trajectories (CSV)")

prepare_btn = widgets.Button(description="📦 Prepare ZIP", button_style='primary')
# Output area that on_prepare uses for progress and status messages.
output1     = widgets.Output()

ui1 = widgets.VBox([
    zipname_field,
    input_cb,
    mask_cb,
    param_cb,
    recon_cb,
    detect_cb,
    tracks_cb,
    prepare_btn,
    output1
])
display(ui1)

# This will hold the name of the ZIP once created; it stays None until
# on_prepare has run, and the download cell reads it.
zip_name = None

# 2️⃣ Preparation callback
def _stage_copy(src, dst):
    """Copy `src` to `dst` so it can be added to the ZIP under a fixed name.

    Returns ``None`` on success, otherwise a short description of the failure.
    """
    try:
        shutil.copy(src, dst)
    except shutil.SameFileError:
        # The source already lives under the export name; nothing to copy.
        return None
    except (OSError, TypeError) as err:
        # OSError: missing or unreadable file; TypeError: no path set (None).
        return f"{type(err).__name__}: {err}"
    return None


def on_prepare(_):
    """Write the selected outputs to disk and bundle them into a ZIP archive.

    Click handler for ``prepare_btn``. Every ticked checkbox contributes one
    file, written to the current working directory under a fixed name. The
    napari helper scripts found in ``scripts/`` and a README are then added,
    and everything is stored in the archive named in ``zipname_field``
    (falling back to ``default_zip_name`` when the field is empty). The
    archive name is published through the global ``zip_name``, which the
    download cell reads.

    Input or mask files that cannot be copied are skipped and listed below
    the success message. If no data file could be staged at all, a warning is
    shown and no archive is created.

    Args:
        _: The clicked button (unused).
    """
    global zip_name
    prepare_btn.disabled = True
    try:
        with output1:
            output1.clear_output()
            label    = HTML(html_code_wait+'<b>📦 Creating ZIP…</b>')
            display(label)

            # Gather files
            files_to_zip = []
            skipped = []  # notes about outputs that had to be left out

            if input_cb.value:
                # globals().get keeps this best-effort when no input was loaded.
                problem = _stage_copy(globals().get('filename'), 'input.dv')
                if problem is None:
                    files_to_zip.append('input.dv')
                else:
                    skipped.append(f"Input data not included ({problem})")
            if mask_cb.value:
                problem = _stage_copy(globals().get('mask_filename'), 'mask_file.tif')
                if problem is None:
                    files_to_zip.append('mask_file.tif')
                else:
                    skipped.append(f"Mask file not included ({problem})")
            if param_cb.value:
                with open("parameters.json","w") as jf:
                    json.dump(parameters, jf, indent=4)
                files_to_zip.append("parameters.json")
            if recon_cb.value and 'result' in globals():
                import tifffile
                tifffile.imwrite("reconstruction.tif", np.array(result, dtype=np.float32))
                files_to_zip.append("reconstruction.tif")
            if detect_cb.value and 'detections' in globals():
                df = detections if isinstance(detections, pd.DataFrame) else pd.DataFrame(detections)
                # The `cls` column is exported under the name `shape_index`.
                if 'cls' in df.columns:
                    df = df.rename(columns={'cls':'shape_index'})
                df.to_csv("detections.csv", index=False)
                files_to_zip.append("detections.csv")
            if tracks_cb.value:
                # Prefer the filtered tracks; fall back to the unfiltered ones.
                key = 'filtered_trajectories' if 'filtered_trajectories' in globals() else \
                      'trajectories'          if 'trajectories'          in globals() else None
                if key:
                    df = globals()[key]
                    if 'cls' in df.columns:
                        df = df.rename(columns={'cls': 'shape_index'})
                    df.to_csv("trajectories.csv", index=False)
                    files_to_zip.append("trajectories.csv")

            # This check must come before the README is appended: afterwards
            # the list can never be empty and the warning would be unreachable.
            if not files_to_zip:
                output1.clear_output()
                print("⚠️ No files selected to zip.")
                for note in skipped:
                    print(f"   {note}")
                return

            for script in ['napari_show_detections.py','napari_show_trajectories.py']:
                src = os.path.join("scripts", script)
                if os.path.exists(src):
                    shutil.copy(src, script)
                    files_to_zip.append(script)
            # always include README
            with open("README.txt","w") as f:
                f.write(readme_text)
            files_to_zip.append("README.txt")

            # derive zip name; an empty field would make ZipFile fail
            zip_name  = zipname_field.value.strip() or default_zip_name

            # show progress bar
            from ipywidgets import IntProgress, VBox
            progress = IntProgress(min=0, max=len(files_to_zip), description='Zipping:')
            display(VBox([progress]))

            # write the archive
            with zipfile.ZipFile(zip_name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
                for i, fname in enumerate(files_to_zip, start=1):
                    zf.write(fname)
                    progress.value = i

            # 3️⃣ Clear logs & show only success
            output1.clear_output()
            # Hide checkboxes and button
            for control in (zipname_field, input_cb, mask_cb, param_cb,
                            recon_cb, detect_cb, tracks_cb, prepare_btn):
                control.layout.display = 'none'
            display(widgets.HTML(f'<b>✅ {zip_name} is ready!</b>'))
            for note in skipped:
                print(f"⚠️ {note}")
    finally:
        # Re-enable the button even after a failure so the user can retry;
        # after a successful run it is hidden anyway.
        prepare_btn.disabled = False


prepare_btn.on_click(on_prepare)

❗ Click the <img src="https://shape.utia.cas.cz/files/endocytosis/colab_images/play.png" height="20px"> button below — and click it again each time you change the destination.

In [None]:
#@title ⬇️ Download or Save to Drive

import os, shutil
from IPython.display import display
import ipywidgets as widgets
from ipyfilechooser import FileChooser

# Colab helpers
from google.colab import files as colab_files, drive
from ipywidgets import IntProgress, HTML as WHTML, VBox

# 1️⃣ UI widgets
# Destination selector; 'local' (browser download) is the default.
dest_radio  = widgets.RadioButtons(
    options=[('Download locally','local'),
             ('Save to Drive','drive')],
    value='local',
    description='Destination:'
)

# Folder picker for the Drive destination. It starts hidden and is shown by
# on_dest_change2 once 'Save to Drive' is selected.
drive_chooser = FileChooser('.', title='Pick a Drive folder', show_hidden=False)
drive_chooser.show_only_dirs = True
# NOTE(review): `hide_file` does not look like a documented ipyfilechooser
# attribute — confirm that this assignment has any effect.
drive_chooser.hide_file     = True
drive_chooser.layout.display = 'none'

execute_btn = widgets.Button(description="▶️ Execute", button_style='primary')
# Output area that on_execute uses for progress and status messages.
output2     = widgets.Output()

ui2 = widgets.VBox([
    dest_radio,
    drive_chooser,
    execute_btn,
    output2
])
display(ui2)

# 2️⃣ Show/hide chooser & mount on-demand
def on_dest_change2(change):
    """React to a change of the destination radio buttons.

    Selecting 'drive' mounts Google Drive if ``/content/drive/MyDrive`` is not
    yet a directory, points the folder chooser at it and shows the chooser.
    Any other selection hides the chooser.

    Args:
        change: Trait-change dict from ipywidgets; only ``change['new']`` is read.
    """
    if change['new'] != 'drive':
        drive_chooser.layout.display = 'none'
        return

    drive_root = '/content/drive/MyDrive'
    already_mounted = os.path.isdir(drive_root)
    if not already_mounted:
        drive.mount('/content/drive')
    drive_chooser.reset(drive_root)
    drive_chooser.layout.display = 'block'

# Fire the handler whenever the selected value changes.
dest_radio.observe(on_dest_change2, names='value')

# 3️⃣ Execution callback with single download + Drive‐copy progress
# Guards against starting the browser download twice within one run of this cell.
download_triggered = False

def on_execute(_):
    """Deliver the prepared ZIP: download it locally or copy it to Google Drive.

    Click handler for ``execute_btn``. The controls are hidden as soon as the
    button is clicked, so the cell has to be run again to pick another
    destination or to retry after a warning.

    The archive is looked up through the global ``zip_name`` set by the
    Prepare cell. A missing, unset (``None``) or empty name produces the
    "ZIP not found" warning instead of an exception.

    Args:
        _: The clicked button (unused).
    """
    global download_triggered
    execute_btn.disabled = True  # prevent re-clicks
    dest_radio.layout.display = 'none'
    execute_btn.layout.display = 'none'
    drive_chooser.layout.display = 'none'

    with output2:
        output2.clear_output()

        # zip_name starts out as None in the Prepare cell; os.path.exists(None)
        # raises TypeError, so test for a usable value first.
        zip_path = globals().get('zip_name')
        if not zip_path or not os.path.exists(zip_path):
            print("⚠️ ZIP not found—please run the Prepare cell first.")
            return

        if dest_radio.value == 'local':
            if not download_triggered:
                colab_files.download(zip_path)
                download_triggered = True
            else:
                print("✅ Download already triggered.")
            return

        # Save to Drive with a progress bar
        target = drive_chooser.selected_path
        if not target or not os.path.isdir(target):
            print("⚠️ Please pick a valid Drive folder above.")
            return

        # Use only the file name, so a ZIP name that contains a directory or
        # an absolute path still lands inside the chosen folder.
        dst = os.path.join(target, os.path.basename(zip_path))
        if os.path.exists(dst):
            display(widgets.HTML(f'✅ Already saved at:<br><code>{dst}</code>'))
            return

        # chunked copy, so the progress bar can advance while writing
        total_size = os.path.getsize(zip_path)
        progress   = IntProgress(min=0, max=total_size, description='Saving:')
        label      = WHTML('<b>⏳ Saving to Drive…</b>')
        display(VBox([label, progress]))

        try:
            with open(zip_path, 'rb') as src, open(dst, 'wb') as dest:
                chunk_size = 1_024*1_024  # 1 MB
                written = 0
                while True:
                    chunk = src.read(chunk_size)
                    if not chunk:
                        break
                    dest.write(chunk)
                    written += len(chunk)
                    progress.value = written
        except Exception:
            # Do not leave a truncated ZIP behind: the next run would report
            # it as "Already saved".
            if os.path.exists(dst):
                os.remove(dst)
            raise

        progress.close()
        label.value = '<b>✅ Saved to Drive!</b><br /><i>(It may take a minute for the file to appear in Drive.)</i>'

execute_btn.on_click(on_execute)