# Data Scientist (DS) - Federated SAM2 Medical Image Segmentation

This notebook is for **Data Scientists** who want to train a federated SAM2 model for medical image segmentation across multiple hospitals (Data Owners).

## What this notebook does:
1. **Login** as a Data Scientist using your Google account
2. **Add peers** (Data Owners/hospitals) to collaborate with
3. **Explore datasets** available at each Data Owner site
4. **Submit** FL training jobs to Data Owners
5. **Coordinate** federated learning and aggregate LoRA adapters
6. **Save** the final aggregated model

## Key Benefits:
- Train on data from multiple hospitals without accessing raw data
- Only lightweight LoRA adapters (~2-8 MB) are transferred
- Privacy-preserving medical AI
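You can estimate adapter size from how LoRA works. Each adapted weight matrix of shape (d_out, d_in) gets two low-rank factors, A (r × d_in) and B (d_out × r), which add r·(d_in + d_out) parameters. The sketch below does that arithmetic. The layer shapes are hypothetical placeholders, not the layers this project actually adapts in SAM2, and it assumes 4-byte float32 storage.

```python
from typing import Iterable, Tuple


def lora_param_count(layer_shapes: Iterable[Tuple[int, int]], rank: int) -> int:
    """Extra parameters LoRA adds: rank * (d_in + d_out) per adapted matrix."""
    return sum(rank * (d_in + d_out) for d_out, d_in in layer_shapes)


def lora_size_mb(layer_shapes: Iterable[Tuple[int, int]], rank: int, bytes_per_param: int = 4) -> float:
    """Approximate adapter size in MB (1 MB = 1024**2 bytes), ignoring file-format overhead."""
    return lora_param_count(layer_shapes, rank) * bytes_per_param / 1024**2


# Hypothetical example: rank-8 adapters on the q and v projections of
# 24 attention blocks with hidden size 1024 (illustrative numbers only).
shapes = [(1024, 1024)] * 2 * 24
print(lora_param_count(shapes, rank=8))          # 786432
print(f"{lora_size_mb(shapes, rank=8):.2f} MB")  # 3.00 MB
```

With these assumed shapes the estimate comes to 3 MB. Storing the adapters in fp16 (`bytes_per_param=2`) would halve it.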

## Prerequisites
1. Go to https://colab.research.google.com/
2. Upload this notebook with `File` -> `Upload Notebook`

## Install Dependencies

In [None]:
# Install syft-flwr from the development branch
!uv pip install -v "git+https://github.com/OpenMined/syft-flwr.git@feat/syft-client-p2p"

## Login as Data Scientist

Login using your Google account. This will:
- Authenticate with Google Drive
- Create your SyftBox folder structure
- Enable P2P communication with Data Owners

In [None]:
import syft_client as sc
import syft_flwr

print(f"{sc.__version__ = }")
print(f"{syft_flwr.__version__ = }")

ds_email = input("Enter the Data Scientist's email: ")
ds_client = sc.login_ds(email=ds_email)

## Add Peer Data Owners

Add the Data Owners (hospitals) you want to collaborate with.
They must also add you as a peer from their side.
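Peering only works after both sides have added each other. The pure-Python model below captures that rule and shows which connections are still pending. It does not call `syft_client`. `peer_lists` is a hypothetical mapping from each user's email to the set of emails they have added.

```python
from typing import Dict, List, Set


def is_connected(peer_lists: Dict[str, Set[str]], a: str, b: str) -> bool:
    """True only if a has added b AND b has added a."""
    return b in peer_lists.get(a, set()) and a in peer_lists.get(b, set())


def pending_peers(peer_lists: Dict[str, Set[str]], me: str) -> List[str]:
    """Peers that `me` has added who have not yet added `me` back."""
    return sorted(p for p in peer_lists.get(me, set()) if not is_connected(peer_lists, me, p))


# Hypothetical state: the DS added two hospitals, but only one added the DS back.
peer_lists = {
    "ds@example.com": {"do1@example.com", "do2@example.com"},
    "do1@example.com": {"ds@example.com"},
    "do2@example.com": set(),
}
print(is_connected(peer_lists, "ds@example.com", "do1@example.com"))  # True
print(pending_peers(peer_lists, "ds@example.com"))  # ['do2@example.com']
```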

In [None]:
# Add first Data Owner (Hospital 1)
do1_email = input("Enter the First Data Owner's email: ")
ds_client.add_peer(do1_email)

print("\nCurrent peers:")
ds_client.peers

In [None]:
# Add second Data Owner (Hospital 2) - Optional
# do2_email = input("Enter the Second Data Owner's email: ")
# ds_client.add_peer(do2_email)
# ds_client.peers

## Explore Datasets

View available medical imaging datasets at each Data Owner site.
You can only see metadata and mock (anonymized) samples, not the actual private data.

In [None]:
# View datasets from DO1
do1_datasets = ds_client.datasets.get_all(datasite=do1_email)
do1_datasets

In [None]:
# View dataset details
if len(do1_datasets) > 0:
    do1_datasets[0].describe()

In [None]:
# View datasets from DO2 (if added)
# do2_datasets = ds_client.datasets.get_all(datasite=do2_email)
# do2_datasets

## Setup FL Project

Mount Google Drive to access the FL project code.

In [None]:
from google.colab import drive
from pathlib import Path

drive.mount('/content/drive')

# FL project location in your Google Drive
SYFT_FLWR_PROJECT_PATH = Path("/content/drive/MyDrive/fl-sam2-segmentation")

# If the project doesn't exist, ask the user to copy it from the repository
if not SYFT_FLWR_PROJECT_PATH.exists():
    print(f"Project not found at {SYFT_FLWR_PROJECT_PATH}")
    print("Please copy the fl-sam2-segmentation folder to your Google Drive")
else:
    print(f"Project found at {SYFT_FLWR_PROJECT_PATH}")
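Before bootstrapping, check that the copied folder has the expected layout. In this sketch, `fl_sam2_segmentation/` comes from the package path used later in this notebook. `pyproject.toml` is an assumption based on how Flower apps are usually configured. Adjust `EXPECTED_ENTRIES` to match the real project.

```python
from pathlib import Path
from typing import Iterable, List

# Assumed layout: the package folder appears elsewhere in this notebook;
# pyproject.toml is an assumption (typical for Flower apps).
EXPECTED_ENTRIES = ("fl_sam2_segmentation", "pyproject.toml")


def missing_entries(project: Path, expected: Iterable[str] = EXPECTED_ENTRIES) -> List[str]:
    """Return the expected files/folders that are not present under `project`."""
    return [name for name in expected if not (project / name).exists()]


# Usage (after the cell above):
# print("Missing:", missing_entries(SYFT_FLWR_PROJECT_PATH) or "nothing")
```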

## Bootstrap FL Project

Configure the FL project with the aggregator (you) and datasites (Data Owners).

In [None]:
import syft_flwr

try:
    # Remove any main.py left over from a previous bootstrap
    main_py = SYFT_FLWR_PROJECT_PATH / "main.py"
    if main_py.exists():
        main_py.unlink()
    
    print(f"syft_flwr version = {syft_flwr.__version__}")
    
    # Get all peer emails
    do_emails = [peer.email for peer in ds_client.peers]
    print(f"Data Owners: {do_emails}")
    print(f"Aggregator: {ds_email}")
    
    # Bootstrap the project
    syft_flwr.bootstrap(
        SYFT_FLWR_PROJECT_PATH, 
        aggregator=ds_email, 
        datasites=do_emails
    )
    print("Bootstrapped project successfully")
except Exception as e:
    print(f"Error: {e}")
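`bootstrap` receives the peer list directly, so a quick pre-check can catch an empty or inconsistent configuration before anything is written. `check_bootstrap_inputs` is a hypothetical helper, not part of `syft_flwr`. Treating an aggregator that is also listed as a datasite as a problem is an assumption of this sketch.

```python
from typing import List


def check_bootstrap_inputs(aggregator: str, datasites: List[str]) -> List[str]:
    """Return human-readable problems with the configuration (empty list if none found)."""
    problems = []
    if not datasites:
        problems.append("no datasites: add at least one Data Owner as a peer")
    if aggregator in datasites:
        problems.append(f"aggregator {aggregator} is also listed as a datasite")
    duplicates = sorted({d for d in datasites if datasites.count(d) > 1})
    if duplicates:
        problems.append(f"duplicate datasites: {duplicates}")
    return problems


# Usage (ds_email and do_emails come from the cells above):
# print(check_bootstrap_inputs(ds_email, do_emails) or "Configuration looks OK")
```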

## Run Local Simulation (Optional)

Test the FL pipeline locally with mock data before submitting to real Data Owners.

In [None]:
RUN_SIMULATION = False  # Set to True to run simulation

if RUN_SIMULATION:
    # Clean pycache
    import shutil
    pycache = SYFT_FLWR_PROJECT_PATH / "fl_sam2_segmentation" / "__pycache__"
    if pycache.exists():
        shutil.rmtree(pycache)
    
    # Get mock dataset URLs
    mock_dataset_urls = []
    for do_email in do_emails:
        datasets = ds_client.datasets.get_all(datasite=do_email)
        if datasets:
            mock_dataset_urls.append(datasets[0].mock_url)
    
    print(f"Running simulation with mock paths: {mock_dataset_urls}")
    syft_flwr.run(SYFT_FLWR_PROJECT_PATH, mock_dataset_urls)

## Submit FL Job to Data Owners

Submit the FL training job to each Data Owner.
They will need to approve and run the job on their local data.

In [None]:
# Clean pycache before submitting
import shutil
pycache = SYFT_FLWR_PROJECT_PATH / "fl_sam2_segmentation" / "__pycache__"
if pycache.exists():
    shutil.rmtree(pycache)

# Submit to first Data Owner
ds_client.submit_python_job(
    user=do1_email,
    code_path=str(SYFT_FLWR_PROJECT_PATH),
    job_name="fl-sam2-segmentation-training",
)
print(f"Job submitted to {do1_email}")
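The cells in this notebook clean only `fl_sam2_segmentation/__pycache__`. If the project has nested packages, a recursive variant removes every `__pycache__` folder before submission. `clean_pycache` is a hypothetical helper that uses only `pathlib` and `shutil`.

```python
import shutil
from pathlib import Path


def clean_pycache(project: Path) -> int:
    """Delete every __pycache__ directory under `project`; return how many were found."""
    # Materialize the list first so the tree isn't modified while rglob walks it.
    caches = [p for p in project.rglob("__pycache__") if p.is_dir()]
    for cache in caches:
        shutil.rmtree(cache, ignore_errors=True)
    return len(caches)


# Usage:
# print(f"Removed {clean_pycache(SYFT_FLWR_PROJECT_PATH)} __pycache__ folder(s)")
```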

In [None]:
# Submit to second Data Owner (if added)
# ds_client.submit_python_job(
#     user=do2_email,
#     code_path=str(SYFT_FLWR_PROJECT_PATH),
#     job_name="fl-sam2-segmentation-training",
# )
# print(f"Job submitted to {do2_email}")
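With more than one Data Owner, a loop can replace the per-hospital submission cells. `submit_to_all` is a hypothetical helper that reuses only the `submit_python_job(user=..., code_path=..., job_name=...)` call shown above. It assumes a failed submission raises an exception.

```python
from typing import Iterable, List


def submit_to_all(client, emails: Iterable[str], code_path: str,
                  job_name: str = "fl-sam2-segmentation-training") -> List[str]:
    """Submit the same job to each Data Owner; return the emails whose submission failed."""
    failed = []
    for email in emails:
        try:
            client.submit_python_job(user=email, code_path=code_path, job_name=job_name)
            print(f"Job submitted to {email}")
        except Exception as exc:  # keep going so one bad peer doesn't block the rest
            print(f"Submission to {email} failed: {exc}")
            failed.append(email)
    return failed


# Usage (do_emails comes from the bootstrap cell):
# failed = submit_to_all(ds_client, do_emails, str(SYFT_FLWR_PROJECT_PATH))
```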

In [None]:
# Check submitted jobs
ds_client.jobs

## Run FL Server (Aggregator)

Run the FL server to:
1. Wait for Data Owners to complete their local training
2. Aggregate LoRA adapter weights using FedAvg
3. Coordinate multiple training rounds
4. Save the final aggregated model
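FedAvg combines the clients' parameters as a weighted mean, where client k's weight is its share of the training examples, n_k / Σn. Below is a minimal NumPy sketch over LoRA state dicts. The key names and example counts are made up. Whether this project's server strategy weights by example count is an assumption, although Flower's built-in `FedAvg` does.

```python
from typing import Dict, List

import numpy as np


def fedavg(states: List[Dict[str, np.ndarray]], num_examples: List[int]) -> Dict[str, np.ndarray]:
    """Weighted average of client state dicts, weighting each client by its example count."""
    if not states or len(states) != len(num_examples):
        raise ValueError("need one example count per client state")
    total = sum(num_examples)
    return {
        key: sum(state[key] * (n / total) for state, n in zip(states, num_examples))
        for key in states[0]
    }


# Two hypothetical hospitals with 100 and 300 training images.
h1 = {"lora_A": np.array([1.0, 1.0]), "lora_B": np.array([0.0])}
h2 = {"lora_A": np.array([3.0, 3.0]), "lora_B": np.array([4.0])}
print(fedavg([h1, h2], [100, 300]))
```

The hospital with three times as many images pulls the average three-quarters of the way toward its values, so `lora_A` becomes [2.5, 2.5] and `lora_B` becomes [3.0].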

In [None]:
# Install additional dependencies for SAM2
!uv pip install \
    "torch>=2.0.0" \
    "torchvision>=0.15.0" \
    "peft>=0.5.0" \
    "opencv-python>=4.8.0" \
    "Pillow>=10.0.0" \
    "numpy>=1.24.0" \
    "scipy>=1.10.0" \
    "scikit-learn>=1.3.0" \
    "pandas>=2.0.0" \
    "loguru>=0.7.0" \
    "tqdm>=4.65.0"
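To confirm the installs took effect, list the installed distribution versions with the standard-library `importlib.metadata`. Lookups use distribution names. If Colab provides OpenCV through a different distribution (for example `opencv-python-headless`), `opencv-python` may show as missing even though `cv2` imports.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    result: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(
    ["torch", "torchvision", "peft", "opencv-python", "Pillow", "numpy", "scipy",
     "scikit-learn", "pandas", "loguru", "tqdm"]
).items():
    print(f"{name:15} {ver or 'NOT INSTALLED'}")
```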

In [None]:
# Run the FL server (aggregator)
assert SYFT_FLWR_PROJECT_PATH.exists(), f"Project path does not exist: {SYFT_FLWR_PROJECT_PATH}"
assert (SYFT_FLWR_PROJECT_PATH / "main.py").exists(), "main.py not found; run the Bootstrap FL Project cell first"

syftbox_folder = f"/content/SyftBox_{ds_email}"

!SYFTBOX_EMAIL="{ds_email}" SYFTBOX_FOLDER="{syftbox_folder}" \
    uv run {str(SYFT_FLWR_PROJECT_PATH / "main.py")}
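The shell cell relies on IPython expanding `{ds_email}` inside a `!` command. An alternative is to build the command and environment in Python and run them with `subprocess.run`. This avoids shell quoting and, with `check=True`, raises if the server exits with an error. `server_command` is a hypothetical helper. The environment variable names are the ones used above.

```python
import os
from pathlib import Path
from typing import Dict, List, Tuple


def server_command(ds_email: str, project: Path, syftbox_folder: str) -> Tuple[List[str], Dict[str, str]]:
    """Build the `uv run main.py` command plus an environment with the SyftBox variables set."""
    env = {**os.environ, "SYFTBOX_EMAIL": ds_email, "SYFTBOX_FOLDER": syftbox_folder}
    return ["uv", "run", str(project / "main.py")], env


# Usage (equivalent to the shell cell above):
# import subprocess
# cmd, env = server_command(ds_email, SYFT_FLWR_PROJECT_PATH, syftbox_folder)
# subprocess.run(cmd, env=env, check=True)
```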

## View Results

After training completes, check for the saved aggregated LoRA weights and load them into the model.

In [None]:
# Check for saved model weights
from pathlib import Path

output_dir = Path.home() / ".syftbox/rds/"
weights_path = output_dir / "sam2_lora_weights"

if weights_path.exists():
    print(f"Model weights saved at: {weights_path}")
    print("Files:")
    for f in weights_path.iterdir():
        print(f"  - {f.name}")
else:
    print("No model weights found yet. Training may still be in progress.")
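If `final_weights.pt` is not there yet, you can pick the most recent checkpoint by modification time. This sketch assumes checkpoints are saved as `*.pt` files directly in `weights_path`. The per-round file naming is not documented in this notebook, and `latest_weights` is a hypothetical helper.

```python
from pathlib import Path
from typing import Optional


def latest_weights(weights_dir: Path, pattern: str = "*.pt") -> Optional[Path]:
    """Return the most recently modified file matching `pattern`, or None if there is none."""
    candidates = [p for p in weights_dir.glob(pattern) if p.is_file()]
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)


# Usage (weights_path comes from the cell above):
# print(latest_weights(weights_path))
```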

In [None]:
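# Load and inspect the final model
import sys

import torch

# Make the project package importable: it lives in Drive, not in Colab's working directory
sys.path.insert(0, str(SYFT_FLWR_PROJECT_PATH))
from fl_sam2_segmentation.task import create_model

# Create model architecture (img_size and lora_rank should match the training configuration)
model = create_model(img_size=512, lora_rank=8)

# Load aggregated weights
final_weights_path = weights_path / "final_weights.pt"  # or the latest round's checkpoint
if final_weights_path.exists():
    weights = torch.load(final_weights_path, map_location="cpu")
    model.load_adapter_state_dict(weights)
    print("Loaded aggregated weights!")

    # Count trainable parameters
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable:,}")
else:
    print(f"No weights found at {final_weights_path}")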

## Clean Up

Optional: Delete the SyftBox folder from your Drive when done.

In [None]:
# WARNING: This will delete your SyftBox folder!
# Uncomment to run:
# ds_client.delete_syftbox()

## Debug (Optional)

View your SyftBox folder structure for debugging.

In [None]:
!sudo apt install tree -qq

In [None]:
!tree ./drive/MyDrive/SyftBox
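If `apt` is unavailable, a short pure-Python stand-in for `tree` does the same job. `print_tree` is a hypothetical helper. `max_depth` keeps the output manageable for a large SyftBox folder.

```python
from pathlib import Path


def print_tree(root: Path, max_depth: int = 3, _prefix: str = "", _depth: int = 0) -> None:
    """Print a directory tree like `tree`, limited to `max_depth` levels (directories first)."""
    if _depth == 0:
        print(root)
    if _depth >= max_depth:
        return
    entries = sorted(root.iterdir(), key=lambda p: (p.is_file(), p.name))
    for i, entry in enumerate(entries):
        last = i == len(entries) - 1
        print(f"{_prefix}{'└── ' if last else '├── '}{entry.name}")
        if entry.is_dir():
            print_tree(entry, max_depth, _prefix + ("    " if last else "│   "), _depth + 1)


# Usage:
# print_tree(Path("/content/drive/MyDrive/SyftBox"), max_depth=2)
```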