# Example: Full Pipeline - From Synthetic Data Tuning to Model Evaluation

This notebook demonstrates the complete SynthMT pipeline:
1. **Tune synthetic data** to match real microscopy images using embedding-based optimization
2. **Generate synthetic samples** with ground-truth masks
3. **Tune SAM3Text hyperparameters** on the generated synthetic data
4. **Evaluate segmentation and downstream metrics**

This end-to-end workflow shows how to create domain-specific synthetic data and use it to optimize a foundation model for microtubule segmentation, without requiring manual annotations.

![Full Pipeline Overview](images/data_gen_overview.png)

In [None]:
import os
from functools import partial

import numpy as np
import optuna
import matplotlib.pyplot as plt
from tqdm import tqdm

from examples.utils import create_overlay, get_preprocess_params
from synth_mt.config.synthetic_data import SyntheticDataConfig
from synth_mt.config.tuning import TuningConfig
from synth_mt.data_generation.optimization.embeddings import ImageEmbeddingExtractor
from synth_mt.data_generation.optimization.eval import evaluate_synthetic_data_cfg
from synth_mt.data_generation.optimization.metrics import precompute_matric_args
from synth_mt.data_generation.optimization.objective import objective
from synth_mt.benchmark.models.factory import setup_model_factory
from synth_mt.benchmark.metrics import calculate_segmentation_metrics, calculate_downstream_metrics
from synth_mt.model_hpo.model_hpo import define_search_space, objective_function
from synth_mt.utils import preprocessing as pre
from synth_mt.utils import postprocessing as post

---
# Part 1: Tune Synthetic Data to Real Microscopy Images

The optimization process aligns synthetic image distributions with real, annotation-free microscopy data:
- **Real IRM (interference reflection microscopy) images** (unlabeled) define the target distribution
- **Synthetic images** are generated by the parametric generator $P_\theta$
- Both are embedded using **DINOv2** (a pre-trained vision transformer)
- Parameters $\theta$ are iteratively refined to **maximize the cosine similarity** between real and synthetic embeddings
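To make the similarity objective concrete, here is a minimal sketch that scores a batch of synthetic embeddings against reference embeddings by the cosine similarity of their mean vectors. This is an assumption for illustration: `cosine_similarity_score` is a hypothetical helper, the random arrays stand in for DINOv2 features, and SynthMT's `objective` may aggregate embeddings differently (for example per image).

```python
import numpy as np


def cosine_similarity_score(ref_embeddings: np.ndarray, synth_embeddings: np.ndarray) -> float:
    """Cosine similarity between the mean reference and mean synthetic embedding.

    Hypothetical scoring function for illustration; SynthMT's `objective` may
    aggregate embeddings differently.
    """
    # Collapse each (num_images, dim) array into a single mean embedding
    ref_mean = ref_embeddings.mean(axis=0)
    synth_mean = synth_embeddings.mean(axis=0)
    # Cosine similarity: dot product divided by the product of the vector norms
    denom = np.linalg.norm(ref_mean) * np.linalg.norm(synth_mean)
    if denom == 0:
        return 0.0
    return float(np.dot(ref_mean, synth_mean) / denom)


# Placeholder embeddings; a real run would use DINOv2 features
rng = np.random.default_rng(0)
ref = rng.normal(size=(8, 384))
synth_close = ref + 0.1 * rng.normal(size=ref.shape)  # small perturbation of ref
synth_far = rng.normal(size=(8, 384))                 # unrelated embeddings
score_close = cosine_similarity_score(ref, synth_close)
score_far = cosine_similarity_score(ref, synth_far)
print(f"close: {score_close:.3f}, far: {score_far:.3f}")
```

A higher score means the synthetic batch sits closer to the reference set in embedding space, which matches a `maximize` study direction for an objective of this kind.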

## 1.1 Load Tuning Configuration

The tuning configuration specifies reference images, search space, and optimization settings.


In [None]:
cfg_path = "tuning_config_example.json"
tuning_cfg = TuningConfig.load(cfg_path)
tuning_cfg.validate()

print(f"Reference images directory: {tuning_cfg.reference_images_dir}")
print(f"Number of optimization trials: {tuning_cfg.num_trials}")

## 1.2 Compute Reference Embeddings

Extract DINOv2 embeddings from real reference images. These define the target distribution.


In [None]:
embedding_extractor = ImageEmbeddingExtractor(tuning_cfg)
ref_embeddings = embedding_extractor.extract_from_references()
precomputed_kwargs = precompute_matric_args(tuning_cfg, ref_embeddings)

print(f"Extracted embeddings shape: {ref_embeddings.shape}")

## 1.3 Set Up and Run the Optuna Study

We use Optuna with the Tree-structured Parzen Estimator (TPE) sampler to search the generator's parameter space efficiently.

In [None]:
# Set up persistent SQLite storage for the study
db_filename = f"{tuning_cfg.output_config_id}.db"
db_filepath = os.path.join(tuning_cfg.temp_dir, db_filename)
os.makedirs(tuning_cfg.temp_dir, exist_ok=True)
storage_uri = f"sqlite:///{db_filepath}"

# Create study
study_synth = optuna.create_study(
    sampler=optuna.samplers.TPESampler(),
    study_name=tuning_cfg.output_config_id,
    storage=storage_uri,
    direction=tuning_cfg.direction,
    load_if_exists=tuning_cfg.load_if_exists,
)

# Create objective function with pre-computed arguments
objective_fcn = partial(
    objective,
    tuning_cfg=tuning_cfg,
    ref_embeddings=ref_embeddings,
    embedding_extractor=embedding_extractor,
    **precomputed_kwargs,
)

In [None]:
# Run optimization (reduce n_trials for demo purposes)
N_TRIALS_SYNTH = 10  # Increase for better results
study_synth.optimize(objective_fcn, n_trials=N_TRIALS_SYNTH)

print(f"Best trial value: {study_synth.best_value:.4f}")
print(f"Best parameters: {study_synth.best_params}")

## 1.4 Visualize Optimization Results

In [None]:
# Optuna's matplotlib backend returns an Axes object, not a Figure
ax = optuna.visualization.matplotlib.plot_optimization_history(study_synth)
ax.set_title("Synthetic Data Optimization History")
plt.tight_layout()
plt.show()

In [None]:
ax = optuna.visualization.matplotlib.plot_param_importances(study_synth)
ax.set_title("Parameter Importances")
plt.tight_layout()
plt.show()

---
# Part 2: Generate Synthetic Samples

Using the best configuration from optimization, generate synthetic images with ground-truth masks.

In [None]:
# Get best configuration
best_trial = study_synth.best_trial
best_cfg = SyntheticDataConfig.from_trial(best_trial)

# Generate 10 samples (single frames)
NUM_SYNTHETIC_SAMPLES = 10
best_cfg.num_frames = NUM_SYNTHETIC_SAMPLES

print(f"Generating {NUM_SYNTHETIC_SAMPLES} synthetic samples...")

synthetic_frames, synthetic_masks = evaluate_synthetic_data_cfg(
    cfg=best_cfg,
    tuning_cfg=tuning_cfg,
    output_dir=None,
    is_for_expert_validation=False,
)

print(f"Generated {len(synthetic_frames)} frames")
print(f"Frame shape: {synthetic_frames[0].shape}")
print(f"Number of masks per frame: {[len(m) for m in synthetic_masks]}")

## 2.1 Visualize Generated Samples

In [None]:
fig, axs = plt.subplots(2, 5, figsize=(20, 8))
axs = axs.flatten()

for idx, ax in enumerate(axs):
    if idx < len(synthetic_frames):
        overlay = create_overlay(synthetic_frames[idx], synthetic_masks[idx])
        ax.imshow(overlay)
        ax.set_title(f"Sample {idx+1} ({len(synthetic_masks[idx])} MTs)")
    ax.axis("off")

plt.suptitle("Generated Synthetic Samples with Instance Masks", fontsize=14)
plt.tight_layout()
plt.show()

---
# Part 3: Tune SAM3Text Hyperparameters

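Now we use the synthetic data to optimize SAM3Text hyperparameters. Because the synthetic samples come with ground-truth masks, they serve as a small labeled tuning set that adapts the foundation model to our specific domain without any manually annotated real images.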


## 3.1 Prepare Synthetic Dataset for HPO

In [None]:
# Convert synthetic data to the format expected by the HPO objective function
class SyntheticDataset:
    """Simple wrapper to make synthetic data compatible with HPO."""

    def __init__(self, frames, masks):
        self.frames = frames
        self.masks = masks

    def __len__(self):
        return len(self.frames)

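    def __getitem__(self, idx):
        # Return (image, gt_mask_stack, metadata) with masks stacked to shape (N, H, W)
        frame = self.frames[idx]
        masks = self.masks[idx]
        if len(masks) == 0:
            # np.stack raises on an empty list, so return an empty (0, H, W) stack instead
            mask_stack = np.zeros((0, *np.shape(frame)[:2]), dtype=bool)
        else:
            mask_stack = np.stack([np.asarray(m) for m in masks], axis=0)
        return frame, mask_stack, {}

    def get_image_path(self, idx):
        return f"synthetic_sample_{idx}"

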
synthetic_dataset = SyntheticDataset(synthetic_frames, synthetic_masks)
print(f"Synthetic dataset size: {len(synthetic_dataset)}")

## 3.2 Set Up the Model Factory and HPO Study

In [None]:
factory = setup_model_factory()

MODEL_NAME = "sam3text"
METRIC = "IoU"
USE_SKELETON = True
N_TRIALS_HPO = 20  # Increase for better results

# HPO direction
direction = "maximize" if METRIC == "IoU" else "minimize"

# Get postprocessing ranges from synthetic data
min_area, max_area, min_length, max_length = post.get_area_length_ranges(synthetic_dataset)
postprocessing_props = {
    "min_area": min_area,
    "max_area": max_area,
    "min_length": min_length,
    "max_length": max_length,
}

print(f"Model: {MODEL_NAME}")
print(f"Metric: {METRIC}")
print(f"Postprocessing ranges: area=[{min_area}, {max_area}], length=[{min_length}, {max_length}]")
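`post.get_area_length_ranges` derives area and length bounds from the synthetic ground truth, and these bounds are passed to the HPO objective as `postprocessing_props`, presumably to discard implausible predicted instances. The sketch below illustrates only the area part under that assumption; `area_range` and `filter_instances_by_area` are hypothetical helpers, not SynthMT's API, and length filtering is omitted.

```python
import numpy as np


def area_range(mask_stacks) -> tuple:
    """Smallest and largest instance area (pixels) across a list of (N, H, W) stacks.

    Hypothetical helper for illustration; not part of SynthMT's API.
    """
    areas = np.concatenate([s.reshape(s.shape[0], -1).sum(axis=1) for s in mask_stacks])
    return int(areas.min()), int(areas.max())


def filter_instances_by_area(mask_stack: np.ndarray, min_area: float, max_area: float) -> np.ndarray:
    """Keep only instances whose pixel area lies within [min_area, max_area]."""
    areas = mask_stack.reshape(mask_stack.shape[0], -1).sum(axis=1)
    keep = (areas >= min_area) & (areas <= max_area)
    return mask_stack[keep]


# Toy ground truth: two instances with areas 6 and 8 pixels
gt = np.zeros((2, 8, 8), dtype=bool)
gt[0, 1, 1:7] = True
gt[1, 3:5, 2:6] = True
min_a, max_a = area_range([gt])
print(min_a, max_a)  # 6 8

# Toy prediction: areas 5 (too small), 7 (kept) and 24 (too large)
pred = np.zeros((3, 8, 8), dtype=bool)
pred[0, 1, 1:6] = True
pred[1, 4, 0:7] = True
pred[2, 5:8, 0:8] = True
kept = filter_instances_by_area(pred, min_a, max_a)
print(kept.shape)  # (1, 8, 8)
```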


In [None]:
# Set up the Optuna study for model HPO
hpo_db_path = os.path.join(tuning_cfg.temp_dir, f"hpo_{MODEL_NAME}.db")
hpo_storage_uri = f"sqlite:///{hpo_db_path}"

study_hpo = optuna.create_study(
    study_name=f"{MODEL_NAME}_hpo",
    storage=hpo_storage_uri,
    direction=direction,
    load_if_exists=True,
)

## 3.3 Run Hyperparameter Optimization

In [None]:
TEMP_DIR = tuning_cfg.temp_dir
MODEL_DIR = os.path.join(tuning_cfg.temp_dir, "models")
os.makedirs(MODEL_DIR, exist_ok=True)

study_hpo.optimize(
    lambda trial: objective_function(
        trial,
        factory,
        MODEL_NAME,
        synthetic_dataset,
        postprocessing_props,
        METRIC,
        TEMP_DIR,
        MODEL_DIR,
        USE_SKELETON,
    ),
    n_trials=N_TRIALS_HPO,
)

print(f"\nBest {METRIC}: {study_hpo.best_value:.4f}")
print("Best parameters:")
for key, value in study_hpo.best_params.items():
    print(f"  {key}: {value}")

## 3.4 Visualize HPO Results

In [None]:
ax = optuna.visualization.matplotlib.plot_optimization_history(study_hpo)
ax.set_title(f"SAM3Text HPO - {METRIC} Optimization History")
plt.tight_layout()
plt.show()

In [None]:
ax = optuna.visualization.matplotlib.plot_param_importances(study_hpo)
ax.set_title("SAM3Text Hyperparameter Importances")
plt.tight_layout()
plt.show()

---
# Part 4: Evaluate the Optimized Model

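Now we evaluate the optimized SAM3Text model on the synthetic data using both segmentation and downstream metrics. Because these are the same samples used for HPO, the resulting scores are an optimistic, in-sample estimate rather than a measure of performance on real microscopy images.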


## 4.1 Create Model with Best Parameters

In [None]:
# Create model with optimized parameters
best_params = study_hpo.best_params.copy()
best_params["save_dir"] = MODEL_DIR
best_params["work_dir"] = TEMP_DIR

optimized_model = factory.create_model(MODEL_NAME, **best_params)
optimized_model.load_model()

print(f"Loaded optimized {MODEL_NAME} model")

## 4.2 Run Predictions

In [None]:
preprocess_params = get_preprocess_params(optimized_model)

all_gt_masks = []
all_pred_masks = []
all_images = []

for idx in tqdm(range(len(synthetic_dataset)), desc=f"Running {MODEL_NAME}"):
    image, gt_masks, _ = synthetic_dataset[idx]

    processed_image = pre.process_image(image, **preprocess_params)
    pred_mask = optimized_model.predict(processed_image)

    if pred_mask is None:
        print(f"Warning: Model returned None for sample {idx}. Skipping.")
        continue

    all_images.append(image)
    all_gt_masks.append(gt_masks)
    all_pred_masks.append(pred_mask)

print(f"Completed predictions on {len(all_images)} images.")

## 4.3 Visualize Predictions

In [None]:
n_samples = min(5, len(all_images))
# squeeze=False keeps axs two-dimensional even when only one sample is plotted
fig, axs = plt.subplots(n_samples, 3, figsize=(15, 5 * n_samples), squeeze=False)

for sample_idx in range(n_samples):
    gt_overlay = create_overlay(all_images[sample_idx], all_gt_masks[sample_idx])
    pred_overlay = create_overlay(all_images[sample_idx], all_pred_masks[sample_idx])

    axs[sample_idx, 0].imshow(all_images[sample_idx])
    axs[sample_idx, 0].set_title(f"Sample {sample_idx}: Original")

    axs[sample_idx, 1].imshow(gt_overlay)
    axs[sample_idx, 1].set_title(f"Ground Truth ({all_gt_masks[sample_idx].shape[0]} instances)")

    axs[sample_idx, 2].imshow(pred_overlay)
    axs[sample_idx, 2].set_title(f"Prediction ({len(np.unique(all_pred_masks[sample_idx])) - 1} instances)")

    for ax in axs[sample_idx]:
        ax.axis("off")

plt.suptitle(f"Optimized {MODEL_NAME} Predictions vs Ground Truth", fontsize=14)
plt.tight_layout()
plt.show()
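The titles above count ground-truth instances as `shape[0]` of an `(N, H, W)` mask stack, but count predicted instances as the number of non-zero labels, which assumes `predict` returns an `(H, W)` label image with 0 as background. If you need both in one representation, a minimal NumPy sketch under exactly those two layout assumptions (both helper names are hypothetical):

```python
import numpy as np


def label_image_to_stack(label_image: np.ndarray) -> np.ndarray:
    """Convert an (H, W) label image (0 = background) into an (N, H, W) boolean stack."""
    ids = np.unique(label_image)
    ids = ids[ids != 0]  # drop the background label
    if ids.size == 0:
        return np.zeros((0, *label_image.shape), dtype=bool)
    # Broadcast (1, H, W) against (N, 1, 1): one boolean mask per instance id
    return label_image[None, :, :] == ids[:, None, None]


def stack_to_label_image(mask_stack: np.ndarray) -> np.ndarray:
    """Convert an (N, H, W) binary stack into an (H, W) label image with ids 1..N.

    Where instances overlap, the later instance in the stack overwrites earlier ones.
    """
    label_image = np.zeros(mask_stack.shape[1:], dtype=np.int32)
    for instance_id, mask in enumerate(mask_stack, start=1):
        label_image[mask.astype(bool)] = instance_id
    return label_image


labels = np.array([[0, 1, 1],
                   [0, 0, 2],
                   [3, 3, 2]])
stack = label_image_to_stack(labels)
print(stack.shape)  # (3, 3, 3)
print(np.array_equal(stack_to_label_image(stack), labels))  # True
```

Note that converting a stack to a label image is lossy wherever filaments cross, since each pixel can carry only one label; the stack form preserves overlaps.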

## 4.4 Calculate Segmentation Metrics

Compute instance segmentation metrics including:
- **Precision, Recall, F1** at various IoU thresholds
- **Skeletonized IoU (SkIoU)** - IoU computed on skeletonized masks, better suited for filamentous structures
- **Average Precision (AP)** across IoU thresholds
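Precision, recall, F1 and AP all depend on matching predicted instances to ground-truth instances at an IoU threshold. `calculate_segmentation_metrics` implements its own version; the sketch below is an independent NumPy illustration that assumes greedy one-to-one matching (highest IoU first) and the convention, common in cell segmentation, of AP = TP / (TP + FP + FN) averaged over thresholds 0.5 to 0.95. It uses plain IoU rather than the skeletonized variant.

```python
import numpy as np


def pairwise_iou(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """IoU matrix of shape (N_gt, N_pred) for two (N, H, W) boolean mask stacks."""
    gt_flat = gt.reshape(gt.shape[0], -1).astype(np.float64)
    pred_flat = pred.reshape(pred.shape[0], -1).astype(np.float64)
    intersection = gt_flat @ pred_flat.T
    union = gt_flat.sum(axis=1)[:, None] + pred_flat.sum(axis=1)[None, :] - intersection
    return np.divide(intersection, union, out=np.zeros_like(intersection), where=union > 0)


def match_at_threshold(iou: np.ndarray, threshold: float) -> dict:
    """Greedy one-to-one matching: pairs are accepted in descending IoU order."""
    n_gt, n_pred = iou.shape
    matched_gt, matched_pred = set(), set()
    for flat_idx in np.argsort(-iou, axis=None):
        g, p = np.unravel_index(flat_idx, iou.shape)
        if iou[g, p] < threshold:
            break  # all remaining pairs have lower IoU
        if g in matched_gt or p in matched_pred:
            continue  # each instance may be matched at most once
        matched_gt.add(g)
        matched_pred.add(p)
    tp = len(matched_gt)
    fp, fn = n_pred - tp, n_gt - tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision, "recall": recall, "f1": f1}


def mean_ap(iou: np.ndarray, thresholds=None) -> float:
    """Average of TP / (TP + FP + FN) over IoU thresholds (0.5 to 0.95 by default)."""
    if thresholds is None:
        thresholds = np.linspace(0.5, 0.95, 10)
    scores = []
    for t in thresholds:
        m = match_at_threshold(iou, float(t))
        denom = m["tp"] + m["fp"] + m["fn"]
        scores.append(m["tp"] / denom if denom else 1.0)
    return float(np.mean(scores))


# Toy example: one prediction partially covers a GT filament, one misses entirely
gt = np.zeros((2, 4, 6), dtype=bool)
gt[0, 0, :] = True
gt[1, 2, :] = True
pred = np.zeros((2, 4, 6), dtype=bool)
pred[0, 0, :4] = True  # IoU with gt[0] = 4/6
pred[1, 3, :] = True   # no overlap with any GT instance
iou = pairwise_iou(gt, pred)
print(match_at_threshold(iou, 0.5))
print(f"mean AP: {mean_ap(iou):.4f}")
```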


In [None]:
mean_metrics, std_metrics = calculate_segmentation_metrics(
    gt_masks=all_gt_masks,
    pred_masks=all_pred_masks,
    use_skeletonized_version=True,
)

print(f"\n{'='*60}")
print(f"Segmentation Metrics for Optimized {MODEL_NAME}")
print(f"{'='*60}")
print(f"{'Metric':<30} {'Mean':>12} {'Std':>12}")
print(f"{'-'*60}")
for key in sorted(mean_metrics.keys()):
    std_val = std_metrics.get(key, 0.0)
    print(f"{key:<30} {mean_metrics[key]:>12.4f} {std_val:>12.4f}")

## 4.5 Calculate Downstream Metrics

Compute biologically relevant downstream metrics:
- **Count Error** - Difference in number of detected microtubules
- **Length Distribution** - KL divergence between predicted and ground truth length distributions
- **Curvature Distribution** - KL divergence between predicted and ground truth curvature distributions
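The downstream metrics compare per-image counts and distributions of per-filament quantities. A minimal sketch of two of them, assuming a mean absolute count error and a histogram-based KL(P_gt || P_pred) on shared bin edges; the KL direction, binning, smoothing constant, and both helper names are assumptions and may differ from `calculate_downstream_metrics`. Lengths are made-up placeholder data, converted from pixels to micrometers with the same calibration of 9 pixels per micrometer used below.

```python
import numpy as np


def count_error(gt_counts, pred_counts) -> float:
    """Mean absolute difference between predicted and ground-truth instance counts per image."""
    gt_counts = np.asarray(gt_counts, dtype=float)
    pred_counts = np.asarray(pred_counts, dtype=float)
    return float(np.mean(np.abs(pred_counts - gt_counts)))


def histogram_kl_divergence(gt_values, pred_values, bins=20, eps=1e-10) -> float:
    """KL(P_gt || P_pred) between histograms computed on shared bin edges."""
    gt_values = np.asarray(gt_values, dtype=float)
    pred_values = np.asarray(pred_values, dtype=float)
    # Shared edges spanning both samples so the two histograms are comparable
    edges = np.histogram_bin_edges(np.concatenate([gt_values, pred_values]), bins=bins)
    p, _ = np.histogram(gt_values, bins=edges)
    q, _ = np.histogram(pred_values, bins=edges)
    # Add a small constant so empty bins do not produce log(0), then normalize
    p = (p + eps) / (p + eps).sum()
    q = (q + eps) / (q + eps).sum()
    return float(np.sum(p * np.log(p / q)))


# Made-up filament lengths in pixels, converted to micrometers (9 px per µm)
PIXEL_PER_MICROMETER = 9.0
rng = np.random.default_rng(0)
gt_lengths_um = rng.gamma(shape=4.0, scale=9.0, size=300) / PIXEL_PER_MICROMETER
pred_lengths_um = rng.gamma(shape=4.0, scale=7.0, size=280) / PIXEL_PER_MICROMETER

print(f"Count error: {count_error([300], [280]):.1f}")  # Count error: 20.0
print(f"Length KL divergence: {histogram_kl_divergence(gt_lengths_um, pred_lengths_um):.4f}")
```

The same histogram comparison applies unchanged to curvature values; only the per-filament quantity being binned changes.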


In [None]:
downstream_metrics = calculate_downstream_metrics(
    gt_masks=all_gt_masks,
    pred_masks=all_pred_masks,
    pixel_per_micrometer=9.0,  # spatial calibration used to convert pixel lengths to micrometers
)

print(f"\n{'='*60}")
print(f"Downstream Metrics for Optimized {MODEL_NAME}")
print(f"{'='*60}")
print(f"{'Metric':<40} {'Value':>15}")
print(f"{'-'*60}")
for key, value in downstream_metrics.items():
    if isinstance(value, float):
        print(f"{key:<40} {value:>15.4f}")
    else:
        print(f"{key:<40} {str(value):>15}")

---
# Summary

This notebook demonstrated the complete SynthMT pipeline:

1. **Synthetic Data Tuning**: Optimized generation parameters $\theta$ to match real IRM images using DINOv2 embeddings
2. **Sample Generation**: Created 10 synthetic samples with ground-truth instance masks
3. **Model HPO**: Tuned SAM3Text hyperparameters using only synthetic data
4. **Evaluation**: Computed segmentation (IoU, SkIoU, AP) and downstream metrics (count, length, curvature)

Key takeaway: by tuning synthetic data to match real microscopy images, we can adapt foundation models such as SAM3Text to the microtubule segmentation domain without requiring manual annotations on real data.


In [None]:
# Save best parameters for future use
import json

output_config = {
    "synthetic_data_params": study_synth.best_params,
    "model_hpo_params": study_hpo.best_params,
    "best_synth_score": study_synth.best_value,
    "best_model_score": study_hpo.best_value,
}

output_path = os.path.join(tuning_cfg.temp_dir, "full_pipeline_results.json")
with open(output_path, "w") as f:
    json.dump(output_config, f, indent=2)

print(f"Results saved to: {output_path}")