In [4]:
%load_ext autoreload
%autoreload 2

### 3D CNN Instance Segmentation of Proteins in Cryo-ET Tomograms

This tutorial guides you through the process of training 3D U-Nets for instance segmentation of proteins in Cryo-ET tomograms. It draws inspiration from the segmentation framework introduced by E. Moebel and co-authors in DeepFinder. However, this repository introduces new developments in model architectures, data augmentations, and efficient model training, with support for datasets available on both local and remote resources.

#### 🧐 What you'll learn
In this notebook, we demonstrate how to use the octopi workflow to predict 3D protein coordinates from segmentation masks. We illustrate this using Dataset ID: [10440](https://cryoetdataportal.czscience.com/datasets/10440) — a benchmark dataset provided as part of the CZ Imaging Institute Machine Learning Challenge.

This dataset includes experimental tomograms annotated with six known macromolecular species:

* Apoferritin 
* Beta-amylase 
* Beta-galactosidase 
* Ribosome 
* Thyroglobulin 
* Virus-like particles (VLP)

To learn more about the dataset and challenge, see the full preprint here: 📄 [A Machine Learning Challenge for the Instance Segmentation of Proteins in Cryo-ET](https://www.biorxiv.org/content/biorxiv/early/2024/11/21/2024.11.04.621686.full.pdf)

### 📚 Tutorial Overview
The tutorial is structured into three main components (the third is optional):

1. Data Preparation: Generating target volumes that the network will use to predict coordinates.
2. Model Training: Training the 3D U-Net model.
3. Optuna Optimization (Optional): Exploring several model configurations with Bayesian optimization.

**Note:** Inference is provided in the following notebook: `inference.ipynb`

By following this tutorial, you will gain insight into preparing data and training a 3D U-Net model for the instance segmentation of proteins in Cryo-ET tomograms.


#### 🧱 Step 1: Data Preparation: Generate Targets for Training

In this step, we will prepare the target data necessary for training our model and predicting the coordinates of proteins within a tomogram.

We will use the Copick tool to manage the filesystem, extract tomogram IDs, and create spherical targets corresponding to the locations of proteins. The key tasks performed in this cell include:

* **Loading Parameters:** 

We define the size of the target spheres, specify the copick path, voxel size, target file name, and user ID.
* **Generating Targets:**

For each tomogram, we extract particle coordinates, reset the target volume, generate spherical targets based on these coordinates, and save the target data in OME Zarr format. The equivalent CLI tool for this step is:
```
octopi create-targets --help
```
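
To make the sphere-painting idea concrete, here is a minimal NumPy sketch of how a spherical target could be written into a label volume. It is not octopi's implementation: `paint_sphere`, the 60 Å particle radius, and the pick coordinates are hypothetical names and values chosen for illustration. Only `voxel_size` and `radius_scale` are taken from this tutorial.

```python
import numpy as np

def paint_sphere(volume, center_zyx, radius_vox, label):
    """Set every voxel within `radius_vox` of `center_zyx` to `label`, in place."""
    r = int(np.ceil(radius_vox))
    # Bounding box around the sphere, clipped to the volume
    bounds = [
        (max(0, int(c) - r), min(n, int(c) + r + 1))
        for c, n in zip(center_zyx, volume.shape)
    ]
    (z0, z1), (y0, y1), (x0, x1) = bounds
    zz, yy, xx = np.ogrid[z0:z1, y0:y1, x0:x1]
    cz, cy, cx = center_zyx
    inside = (zz - cz) ** 2 + (yy - cy) ** 2 + (xx - cx) ** 2 <= radius_vox ** 2
    volume[z0:z1, y0:y1, x0:x1][inside] = label
    return volume

voxel_size = 10.012        # Å per voxel, as used in this tutorial
radius_scale = 0.7         # as used in this tutorial
particle_radius = 60.0     # Å, hypothetical value for illustration
radius_vox = radius_scale * particle_radius / voxel_size   # about 4.2 voxels

# Hypothetical pick coordinates in Å (z, y, x), converted to voxel units
picks_angstrom = np.array([[160.2, 160.2, 160.2], [300.4, 100.1, 250.3]])
picks_vox = picks_angstrom / voxel_size

target = np.zeros((32, 32, 32), dtype=np.uint8)
for center in picks_vox:
    paint_sphere(target, center, radius_vox, label=1)
```

The second pick sits close to the edge of the volume, so its sphere is clipped by the bounding-box logic rather than raising an indexing error.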

##### 💡 Notes:
* **Data Access via [copick](https://github.com/copick/copick):**

octopi assumes that tomograms and coordinates are accessible through the copick configuration system.

* **Alternative Input Sources:**

If your data is stored in a folder as `*.mrc` volumes (e.g., from another processing pipeline), you can import them using:
```
octopi import-mrc-volumes --help
```

* **Download from the [Data-Portal](https://cryoetdataportal.czscience.com)**

We can also download tomograms from the data-portal in advance, which speeds up processing by avoiding runtime downloads:
```
octopi download-dataportal --help
```
* **Recommended Resolution:**

Tomograms should ideally be resampled to a voxel size of at least 10 Å. This reduces memory usage and speeds up training without significantly sacrificing performance. When importing data from MRC files or downloading directly from the data-portal, we can downsample to the desired resolution with the `--output-voxel-size` flag.
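
As a rough illustration of why resampling helps, the sketch below estimates how the shape and memory footprint of a volume change with voxel size. The helper names and the tomogram shape are hypothetical, and the rounding is only an approximation of what a real resampling tool would produce.

```python
import math

def resampled_shape(shape, voxel_size_in, voxel_size_out):
    """Approximate volume shape after resampling to a new voxel size."""
    scale = voxel_size_in / voxel_size_out
    return tuple(max(1, round(n * scale)) for n in shape)

def float32_gigabytes(shape):
    """Memory needed to hold one float32 volume of the given shape."""
    return math.prod(shape) * 4 / 1e9

original_shape = (368, 1260, 1260)   # hypothetical (z, y, x) shape at 5 Å per voxel
coarse_shape = resampled_shape(original_shape, voxel_size_in=5.0, voxel_size_out=10.0)

print(original_shape, f"{float32_gigabytes(original_shape):.2f} GB")
print(coarse_shape, f"{float32_gigabytes(coarse_shape):.2f} GB")
```

Doubling the voxel size halves each axis, so the voxel count (and the memory per tomogram) drops by roughly a factor of eight.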


In [5]:
from octopi.entry_points.run_create_targets import create_sub_train_targets, create_all_train_targets

# Copick Config
config = '../config.json'

# Target Parameters
target_name = 'targets'
target_user_id = 'octopi'               # These parameters are optional
target_session_id = '0'

# Tomogram Query Information - This is Used to determine the resolution that the targets will be created for. 
voxel_size = 10.012
tomogram_algorithm = 'wbp-denoised-denoiset-ctfdeconv'

# For our segmentation target, we create a sphere whose radius is a fraction of the 
# particle radius provided in the config file.
radius_scale = 0.7

# Optional: Define A Sub-set of tomograms for generating training labels
run_ids = None

To generate the segmentation targets, we can use one of two available functions:
1. Provide a subset of pickable objects and (optionally) their userID / sessionID. This allows training targets to be created from several submission sources.
2. Instead of manually specifying each pick target by name (and potentially its sessionID and/or userID), find all the pickable objects associated with a single query.


In [None]:
# Option 1: We can provide a subset of pickable objects and (optionally) its userID / sessionIds. 
# This allows for creating training targets from varying submission sources.
# Provide inputs as a list of tuples -> [(name, userID, sessionID)]

pick_targets = [
    ('ribosome', 'data-portal', None),
    ('virus-like-particle', 'data-portal', None),
    ('apoferritin', 'data-portal', None)
]

seg_targets = [] # Either provide this variable as an empty list or populate entries in the same format (name, userID, sessionID)

create_sub_train_targets(
    config, pick_targets, seg_targets, voxel_size, radius_scale, tomogram_algorithm,
    target_name, target_user_id, target_session_id, run_ids
)

Creating Targets for the following objects: ribosome, virus-like-particle, apoferritin


  0%|          | 0/7 [00:10<?, ?it/s]

Annotating 88 picks in 16463...


 14%|█▍        | 1/7 [00:19<01:11, 11.87s/it]

Annotating 81 picks in 16464...


 29%|██▊       | 2/7 [00:28<00:52, 10.41s/it]

Annotating 142 picks in 16465...


 43%|████▎     | 3/7 [00:37<00:38,  9.75s/it]

Annotating 83 picks in 16466...


 57%|█████▋    | 4/7 [00:47<00:28,  9.59s/it]

Annotating 163 picks in 16467...


 71%|███████▏  | 5/7 [00:56<00:19,  9.61s/it]

Annotating 148 picks in 16468...


 86%|████████▌ | 6/7 [00:58<00:09,  9.61s/it]

In [3]:
# Option 2: Instead of Manually Specifying Each pickable object, we can provide a single query 
# and it will grab the first available coordinate for each pickable object.
picks_user_id = 'data-portal'
picks_session_id = None

# In this case, we don't have any organelle segmentations that are at 10 Angstroms on the portal
seg_targets = []

create_all_train_targets(
    config, seg_targets, picks_session_id, picks_user_id, 
    voxel_size, radius_scale, tomogram_algorithm, 
    target_name, target_user_id, target_session_id, run_ids
)

Creating Targets for the following objects: apoferritin, beta-amylase, beta-galactosidase, ribosome, thyroglobulin, virus-like-particle, membrane


  0%|          | 0/7 [00:10<?, ?it/s]

Annotating 136 picks in 16463...


 14%|█▍        | 1/7 [00:19<01:12, 12.13s/it]

Annotating 141 picks in 16464...


 29%|██▊       | 2/7 [00:28<00:50, 10.10s/it]

Annotating 191 picks in 16465...


 43%|████▎     | 3/7 [00:37<00:38,  9.57s/it]

Annotating 143 picks in 16466...


 57%|█████▋    | 4/7 [00:46<00:28,  9.64s/it]

Annotating 215 picks in 16467...


 71%|███████▏  | 5/7 [00:55<00:18,  9.36s/it]

Annotating 221 picks in 16468...


 86%|████████▌ | 6/7 [01:04<00:09,  9.16s/it]

Annotating 202 picks in 16469...


100%|██████████| 7/7 [01:06<00:00,  9.43s/it]

Creation of targets complete!





#### Step 2: Training the octopi 🐙 to find macromolecules in Cryo-ET Tomograms

Once our target labels are prepared, we can begin training a deep learning model to identify macromolecular structures in our data. 

The training process is modular and configurable. It involves defining the target segmentation volumes (prepared in Step 1), preparing 3D tomographic input data, and configuring a U-Net-based segmentation model to predict voxel-level class assignments.
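
When choosing `strides` and `dim_in` for the model configuration, it can help to check how the spatial size shrinks through the network. The helper below is a hypothetical sketch that assumes each stride divides the feature map exactly; how octopi or MONAI map the listed strides onto layers may differ in detail (for example, a trailing stride may go unused).

```python
def feature_map_sizes(dim_in, strides):
    """Spatial size after each strided level, assuming exact division by the stride."""
    sizes = [dim_in]
    for stride in strides:
        if sizes[-1] % stride != 0:
            raise ValueError(f"size {sizes[-1]} is not divisible by stride {stride}")
        sizes.append(sizes[-1] // stride)
    return sizes

# Values taken from the model configuration used in this tutorial
print(feature_map_sizes(128, [2, 2, 1, 1]))
```

With these values, only the first two levels downsample, so the coarsest feature maps are a quarter of the input size along each axis.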

In [1]:
from monai.metrics import ConfusionMatrixMetric
from octopi.models import common as builder
from octopi.datasets import generators
from monai.losses import TverskyLoss
from octopi import losses
from octopi.pytorch import trainer 
from octopi import io, utils
import torch, os

########### Input Parameters ###########

# Target Parameters
config = "../config.json"
target_name = 'targets'
target_user_id = 'octopi'
target_session_id = None

# DataGenerator Parameters
num_tomo_crops = 16
tomo_algorithm = 'wbp-denoised-denoiset-ctfdeconv'
voxel_size = 10.012
# In cases where all the tomograms can't be fit in memory, we can train on smaller batches
tomo_batch_size = 25

# Model Parameters
Nclass = 7
model_config = {
        'architecture': 'Unet',            # Model Architecture
        'channels': [32,64,128,128],   # Number of Channels in Each Layer 
        'strides': [2, 2, 1, 1],        # Strides for the convolutional layers
        'num_res_units': 3,                # Number of Residual units
        'num_classes': Nclass,                  # Number of Classes on prediction head (background + numClasses)
        'dropout': 0.05,                    # Drop Out
        'dim_in': 128                      # Input Dimensions [voxels]
    }

model_save_path = 'results'         # Path to save the model
model_weights = None # Path to the pre-trained model weights

# Optional - Specify RunIDs for training and validation data splits. 
trainRunIDs = None
validateRunIDs = None

#### 🧪 Prepare the training module

Next, we instantiate the octopi data generator, which handles on-the-fly loading of sub-volumes from the full tomograms. This is especially helpful when training on large datasets that cannot fit into memory.
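
The idea of sampling sub-volumes can be sketched with plain NumPy. This is only an illustration of aligned random cropping, not the octopi data generator: `sample_random_crops` is a hypothetical helper and the volumes are synthetic.

```python
import numpy as np

def sample_random_crops(tomogram, target, crop_size, num_samples, rng):
    """Draw aligned random cubic crops from a tomogram and its target volume."""
    assert tomogram.shape == target.shape
    assert all(n >= crop_size for n in tomogram.shape)
    crops = []
    for _ in range(num_samples):
        # Random corner chosen so that the crop stays inside the volume
        z, y, x = (int(rng.integers(0, n - crop_size + 1)) for n in tomogram.shape)
        window = (slice(z, z + crop_size), slice(y, y + crop_size), slice(x, x + crop_size))
        crops.append((tomogram[window], target[window]))
    return crops

rng = np.random.default_rng(0)
tomogram = rng.normal(size=(48, 64, 64)).astype(np.float32)   # synthetic stand-in volume
target = (tomogram > 1.5).astype(np.uint8)                    # synthetic stand-in labels
crops = sample_random_crops(tomogram, target, crop_size=32, num_samples=16, rng=rng)
```

Each crop is a view into the full volume, so no data is copied until a batch is assembled. In this tutorial the equivalent settings are `dim_in` (crop size) and `num_tomo_crops` (crops per tomogram).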

We also define the custom loss and metric functions. Here we use a Weighted Focal Tversky Loss, which is well-suited for class-imbalanced volumetric data, and a multi-class confusion matrix metric to compute recall, precision, and F1 score per class.
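
To see why a Tversky-based loss suits imbalanced data, the sketch below computes the Tversky index for binary masks. It assumes the convention used by MONAI's `TverskyLoss`, where `alpha` weights false positives and `beta` weights false negatives; whether `losses.WeightedFocalTverskyLoss` follows the same convention, and how it applies `gamma`, is not confirmed here. `tversky_index` and the toy masks are hypothetical.

```python
import numpy as np

def tversky_index(pred, target, alpha, beta, eps=1e-7):
    """Tversky index for binary masks; alpha weights false positives, beta false negatives."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    tp = np.sum(pred * target)
    fp = np.sum(pred * (1 - target))
    fn = np.sum((1 - pred) * target)
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)

target = np.array([1, 1, 1, 0, 0, 0])
missed = np.array([1, 0, 0, 0, 0, 0])       # two false negatives
overcalled = np.array([1, 1, 1, 1, 1, 0])   # two false positives

for name, pred in [("missed", missed), ("overcalled", overcalled)]:
    print(name, float(tversky_index(pred, target, alpha=0.1, beta=0.9)))
```

With `alpha = 0.1` and `beta = 0.9`, missing particle voxels lowers the index far more than over-predicting them, which pushes the model toward higher recall on rare classes.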

In [15]:
# Single-config training
data_generator = generators.TrainLoaderManager(
    config, 
    target_name, 
    target_session_id = target_session_id,
    target_user_id = target_user_id,
    tomo_algorithm = tomo_algorithm,
    voxel_size = voxel_size,
    Nclasses = Nclass,
    tomo_batch_size = tomo_batch_size )

# Get the data splits
data_generator.get_data_splits(trainRunIDs = trainRunIDs,
                                validateRunIDs = validateRunIDs,
                                train_ratio = 0.9, val_ratio = 0.1, test_ratio = 0.0,
                                create_test_dataset=False)

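# Get the reload frequency. num_epochs is defined here so this cell runs on a fresh
# kernel; it should match the value used in the training cell.
num_epochs = 1000
data_generator.get_reload_frequency(num_epochs)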

# Monai Functions
alpha0 = 0.1
gamma0 = 1.88
weight = 0.13
loss_function = losses.WeightedFocalTverskyLoss(
    gamma=gamma0, alpha = alpha0, beta = (1-alpha0),
    weight_tversky = weight, weight_focal = (1-weight)
)
metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall",'precision','f1 score'], reduction="none")

# Create UNet Model and Load Weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_builder = builder.get_model(model_config['architecture'])
model = model_builder.build_model(model_config)
if model_weights: 
    model.load_state_dict(torch.load(model_weights, weights_only=True))
model.to(device)

# Optimizer
lr = 1e-3   # Learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=0.1)

# Create UNet-Trainer
model_trainer = trainer.ModelTrainer(model, device, loss_function, metrics_function, optimizer)

Number of training samples: 6
Number of validation samples: 1
Number of test samples: 0
All training samples fit in memory. No reloading required.


#### 🏋️ Train the model

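Finally, we initiate model training for a user-defined number of epochs. Validation is run at regular intervals (`val_interval`), and the best-performing model is tracked based on the metric named in `best_metric` (`fBeta3` in this notebook). If the requested metric name is not recognized, the trainer falls back to `avg_f1`, as the logged output below shows.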

Training results and metadata are saved to disk at the end for future analysis and reproducibility.
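
For reference, the F-beta score combines precision and recall as shown in the sketch below. It assumes that the `N` in a metric name such as `fBeta3` is the beta value, which is not stated explicitly in this notebook; `f_beta` is a hypothetical helper, not an octopi function.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta score; beta > 1 weights recall more heavily than precision."""
    denom = beta ** 2 * precision + recall
    if denom == 0:
        return 0.0
    return (1 + beta ** 2) * precision * recall / denom

# Two hypothetical models with swapped precision and recall
print(f_beta(0.9, 0.5, beta=1), f_beta(0.5, 0.9, beta=1))   # identical F1
print(f_beta(0.9, 0.5, beta=3), f_beta(0.5, 0.9, beta=3))   # F3 favours the high-recall model
```

Tracking a recall-weighted metric can be a reasonable choice for particle picking, where missed particles are often harder to recover downstream than false positives.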

In [None]:
# Training Parameters and Frequency to Evaluate Validation Dataset
num_epochs = 1000          # Number of epochs to train
val_interval = 10          # Number of epochs between validation

# Metric for saving checkpoints
# Options: fBetaN, avg_{metric}, or {metric}_class{N},
# where metric is one of recall, precision, f1, or fBetaN
best_metric = 'fBeta3'  

results = model_trainer.train(
    data_generator, model_save_path, max_epochs=num_epochs,
    crop_size=model_config['dim_in'], my_num_samples=num_tomo_crops,
    val_interval=val_interval, best_metric=best_metric, verbose=True
)

# Save parameters and results
parameters_save_name = os.path.join(model_save_path, "training_parameters.yaml")
io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)

results_save_name = os.path.join(model_save_path, "results.json")
io.save_results_to_json(results, results_save_name)


avg_fBeta is not a valid metric! Tracking avg_f1 as the best metric



Loading dataset: 100%|██████████| 6/6 [00:01<00:00,  3.43it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00,  3.44it/s]
Training Progress:   0%|          | 4/1000 [00:28<1:34:53,  5.72s/epoch]

Epoch 5/1000, avg_train_loss: 0.9194


Training Progress:   0%|          | 4/1000 [00:30<1:34:53,  5.72s/epoch]

Epoch 5/1000, avg_f1_score: 0.0123, avg_recall: 0.2163, avg_precision: 0.0064


Training Progress:   1%|          | 9/1000 [00:59<1:36:37,  5.85s/epoch]

Epoch 10/1000, avg_train_loss: 0.9001


Training Progress:   1%|          | 9/1000 [01:01<1:36:37,  5.85s/epoch]

Epoch 10/1000, avg_f1_score: 0.0209, avg_recall: 0.2545, avg_precision: 0.0110


Training Progress:   1%|▏         | 14/1000 [01:29<1:36:25,  5.87s/epoch]

Epoch 15/1000, avg_train_loss: 0.8819


Training Progress:   1%|▏         | 14/1000 [01:31<1:36:25,  5.87s/epoch]

Epoch 15/1000, avg_f1_score: 0.0231, avg_recall: 0.2833, avg_precision: 0.0122


Training Progress:   2%|▏         | 19/1000 [02:00<1:36:24,  5.90s/epoch]

Epoch 20/1000, avg_train_loss: 0.8642


Training Progress:   2%|▏         | 19/1000 [02:02<1:36:24,  5.90s/epoch]

Epoch 20/1000, avg_f1_score: 0.0246, avg_recall: 0.3030, avg_precision: 0.0129


Training Progress:   2%|▏         | 24/1000 [02:32<1:36:38,  5.94s/epoch]

Epoch 25/1000, avg_train_loss: 0.8386


Training Progress:   2%|▏         | 24/1000 [02:33<1:36:38,  5.94s/epoch]

Epoch 25/1000, avg_f1_score: 0.0297, avg_recall: 0.3121, avg_precision: 0.0157


Training Progress:   3%|▎         | 29/1000 [03:02<1:35:17,  5.89s/epoch]

Epoch 30/1000, avg_train_loss: 0.8145


Training Progress:   3%|▎         | 29/1000 [03:04<1:35:17,  5.89s/epoch]

Epoch 30/1000, avg_f1_score: 0.0387, avg_recall: 0.3556, avg_precision: 0.0206


Training Progress:   3%|▎         | 34/1000 [03:34<1:35:13,  5.91s/epoch]

Epoch 35/1000, avg_train_loss: 0.7773


Training Progress:   3%|▎         | 34/1000 [03:36<1:35:13,  5.91s/epoch]

Epoch 35/1000, avg_f1_score: 0.0510, avg_recall: 0.4518, avg_precision: 0.0273


Training Progress:   4%|▍         | 39/1000 [04:04<1:35:09,  5.94s/epoch]

Epoch 40/1000, avg_train_loss: 0.7452


Training Progress:   4%|▍         | 39/1000 [04:06<1:35:09,  5.94s/epoch]

Epoch 40/1000, avg_f1_score: 0.0606, avg_recall: 0.4608, avg_precision: 0.0327


Training Progress:   4%|▍         | 44/1000 [04:35<1:32:21,  5.80s/epoch]

Epoch 45/1000, avg_train_loss: 0.7321


Training Progress:   4%|▍         | 44/1000 [04:37<1:32:21,  5.80s/epoch]

Epoch 45/1000, avg_f1_score: 0.0736, avg_recall: 0.4584, avg_precision: 0.0405


Training Progress:   5%|▍         | 48/1000 [04:54<1:32:58,  5.86s/epoch]

#### 🔁 (Optional): Use Optuna / Bayesian Optimization for Automatic Network Exploration

In this optional step, we use Optuna, a Bayesian optimization framework, to automatically explore different network architectures and training hyperparameters. This process helps identify high-performing configurations without the need for exhaustive manual tuning.

By leveraging intelligent sampling strategies, Optuna can efficiently search through:

* Network depth and width (e.g., number of layers, channels)
* Learning rates, dropout rates, and other optimization parameters
* Loss function weights (e.g., Focal vs Tversky balance)
* Data sampling or augmentation strategies

This automated search is especially useful when working with new biological targets with unknown optimal network setups.

To run the model search outside this notebook, you can use the CLI:
```
octopi model-explore --help
```

In [2]:
from octopi.pytorch.model_search_submitter import ModelSearchSubmit

#########################Input Parameters#########################

# Target Parameters
config = "../config.json"
target_name = 'targets'
target_user_id = 'octopi'
target_session_id = None
tomo_algorithm = 'wbp-denoised-denoiset-ctfdeconv'
voxel_size = 10.012
Nclass = 7

# Define the model type
model_type = "Unet"

# Training and Exploration Parameters
num_trials = 100
num_epochs = 1000
tomo_batch_size = 25
best_metric = 'fBeta3'
val_interval = 10

# MLFlow Experiment Name
mlflow_experiment_name = "model_search"

# Random Seed
random_seed = 42

# Define the train and validate run IDs
trainRunIDs = None
validateRunIDs = None

# Initialize the ModelSearchSubmit class
search = ModelSearchSubmit(
    config,
    target_name, target_user_id, target_session_id,
    tomo_algorithm, voxel_size, Nclass, model_type,
    mlflow_experiment_name, random_seed, 
    num_epochs, num_trials, tomo_batch_size, best_metric, val_interval,
    trainRunIDs, validateRunIDs
)


Training with:
  ../config.json

Number of training samples: 5
Number of validation samples: 2
Number of test samples: 0
All training samples fit in memory. No reloading required.


#### Run the Search Process

In [None]:
search.run_model_search()