<a href="https://colab.research.google.com/github/AdaptiveMotorControlLab/CellSeg3d/blob/main/notebooks/Colab_WNet3D_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **WNet3D: self-supervised 3D cell segmentation**

---

This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).

- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team for bringing Colab into scientific open software.

# **1. Installing dependencies**
---

In [1]:
#@markdown ##Play to install WNet dependencies
!pip install napari-cellseg3d

## **1.2. Load key dependencies**
---

In [1]:
# @title
from pathlib import Path
from napari_cellseg3d.dev_scripts import colab_training as c
from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo, PRETRAINED_WEIGHTS_DIR

## (optional) **1.3. Initialize Weights & Biases integration**
---
If you want to use Weights & Biases (WandB) to monitor and log your training session, run the `wandb.login()` cell.
To enable it, enter your API key when prompted.

# **2. Initialize the Colab session**
---



## **2.1. Check for GPU access**
---

By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:

<font size = 4>Navigate to **Runtime** and select **Change runtime type**.

<font size = 4>For **Runtime type**, make sure it is set to Python 3 (the programming language this notebook is written in).

<font size = 4>Under **Hardware accelerator**, choose GPU (Graphics Processing Unit).


In [2]:
#@markdown ##Execute the cell below to verify if GPU access is available.

import torch
if not torch.cuda.is_available():
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi


You do not have GPU access.
Did you change your runtime?
If the runtime setting is correct then Google did not allocate a GPU for your session
Expect slow performance. To access GPU try reconnecting later



## **2.2. Mount Google Drive**
---
<font size = 4>To use this notebook with your own data, save your data on Google Drive following the data requirements described in section 3.1.

1. <font size = 4> **Run** the **cell** below and click on the provided link.

2. <font size = 4>Log in to your Google account and grant the necessary permissions by clicking 'Allow'.

3. <font size = 4>Copy the generated authorization code and paste it into the cell, then press 'Enter'. This grants Colab access to read and write data to your Google Drive.

4. <font size = 4> After completion, you can view your data in the notebook. Simply click the Files tab on the top left and select 'Refresh'.

In [4]:
# Mount the user's Google Drive in Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

**<font size = 4> If you cannot see your files, reactivate your session by connecting to your hosted runtime.**


<img width="40%" alt="Connect to a hosted runtime." src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/connect_to_hosted.png"><figcaption> Connect to a hosted runtime. </figcaption>

In [5]:
# @title
# Optional: uncomment to log in to Weights & Biases (see section 1.3)
# import wandb
# wandb.login()

# **3. Select your parameters and paths**
---

## **3.1. Choosing parameters**

---

### **Paths to the training data and model**

* <font size = 4>**`training_source`** specifies the folder containing the training data. Each training volume must be a single multipage TIF file.

* <font size = 4>**`model_path`** specifies the directory where the model checkpoints will be saved.

<font size = 4>**Tip:** To easily copy paths, navigate to the 'Files' tab, right-click on a folder or file, and choose 'Copy path'.

### **Training parameters**

* <font size = 4>**`number_of_epochs`** is the number of times the entire training data will be seen by the model. Default: 50

* <font size = 4>**`batch_size`** is the number of images bundled together at each training step. Default: 4

* <font size = 4>**`learning_rate`** is the step size used to update the model's weights. Try decreasing it if the NCuts loss is unstable. Default: 2e-5

* <font size = 4>**`num_classes`** is the number of brightness clusters to segment the image into. Try raising it to 3 if your cells are surrounded by artifacts or "halos" with significantly different brightness. Default: 2

* <font size = 4>**`weight_decay`** is a regularization parameter used to prevent overfitting. Default: 0.01

* <font size = 4>**`validation_frequency`** is how often, in epochs, the provided evaluation data is used to estimate the model's performance (only used if `do_validation` is checked). Default: 2

* <font size = 4>**`intensity_sigma`** is the standard deviation of the feature similarity term. Default: 1

* <font size = 4>**`spatial_sigma`** is the standard deviation of the spatial proximity term. Default: 4

* <font size = 4>**`ncuts_radius`** is the radius for the NCuts loss computation, in pixels. Default: 2

* <font size = 4>**`rec_loss`** is the loss used for the decoder. Can be Mean Squared Error (MSE) or Binary Cross Entropy (BCE). Default: MSE

* <font size = 4>**`n_cuts_weight`** is the weight of the NCuts loss in the weighted sum for the backward pass. Default: 0.5

* <font size = 4>**`rec_loss_weight`** is the weight of the reconstruction loss. Default: 0.005
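
<font size = 4>The sketch below shows what the SoftNCuts and loss-weighting parameters do. It assumes the classic normalized-cuts affinity of Shi and Malik: a Gaussian on the intensity difference times a Gaussian on the spatial distance, set to zero beyond the radius. It also assumes the two losses are combined as a plain weighted sum. The function names `ncuts_affinity` and `weighted_total_loss` are hypothetical, and CellSeg3D's actual implementation may differ in details such as normalization.

```python
import math


def ncuts_affinity(intensity_i, intensity_j, pos_i, pos_j,
                   intensity_sigma=1.0, spatial_sigma=4.0, radius=2):
    """Hypothetical sketch of a normalized-cuts affinity between two voxels.

    Assumes the Shi & Malik form; not taken from the CellSeg3D source.
    """
    # Euclidean distance between the voxel coordinates (z, y, x)
    dist = math.dist(pos_i, pos_j)
    if dist >= radius:
        # Voxels outside the neighborhood radius are not connected
        return 0.0
    # Feature similarity term: voxels with similar brightness get a higher weight
    feature_term = math.exp(-((intensity_i - intensity_j) ** 2) / intensity_sigma ** 2)
    # Spatial proximity term: closer voxels get a higher weight
    spatial_term = math.exp(-(dist ** 2) / spatial_sigma ** 2)
    return feature_term * spatial_term


def weighted_total_loss(ncuts_loss, rec_loss, n_cuts_weight=0.5, rec_loss_weight=0.005):
    """Hypothetical weighted sum of the NCuts and reconstruction losses."""
    return n_cuts_weight * ncuts_loss + rec_loss_weight * rec_loss
```

<font size = 4>Under these assumptions, a smaller `intensity_sigma` penalizes brightness differences more sharply. A larger `spatial_sigma` lets farther voxels within `ncuts_radius` remain strongly connected.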


In [6]:
#@markdown ###Path to the training data:
training_source = "/scratch/th3129/wormID/datasets/000541/sub-20190924-01" #@param {type:"string"}
#@markdown ###Model name and path to model folder:
model_path = "/scratch/th3129/wormID/results" #@param {type:"string"}
#@markdown ---
#@markdown ###Perform validation on a test dataset
do_validation = False #@param {type:"boolean"}
#@markdown ###Path to evaluation data (optional, use if checked above):
eval_source = "./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/vol/" #@param {type:"string"}
eval_target = "./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/lab/" #@param {type:"string"}
#@markdown ---
#@markdown ###Training parameters
number_of_epochs = 50 #@param {type:"number"}
#@markdown ###Default advanced parameters
use_default_advanced_parameters = False #@param {type:"boolean"}
#@markdown <font size = 4>If not, please change:

#@markdown <font size = 3>Training parameters:
batch_size =  4 #@param {type:"number"}
learning_rate = 2e-5 #@param {type:"number"}
num_classes = 2 #@param {type:"number"}
weight_decay = 0.01 #@param {type:"number"}
#@markdown <font size = 3>Validation parameters:
validation_frequency = 2 #@param {type:"number"}
#@markdown <font size = 3>SoftNCuts parameters:
intensity_sigma = 1.0 #@param {type:"number"}
spatial_sigma = 4.0 #@param {type:"number"}
ncuts_radius = 2 #@param {type:"number"}
#@markdown <font size = 3>Reconstruction loss:
rec_loss = "MSE" #@param["MSE", "BCE"]
#@markdown <font size = 3>Weighted sum of losses:
n_cuts_weight = 0.5 #@param {type:"number"}
rec_loss_weight = 0.005 #@param {type:"number"}

# **4. Train the network**
---

<font size = 4>Important Reminder: Google Colab imposes a maximum session time to prevent extended GPU usage, such as for data mining. Ensure your training duration stays under 12 hours. If your training is projected to exceed this limit, consider reducing the `number_of_epochs`.
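
<font size = 4>The 12-hour limit gives a simple upper bound on `number_of_epochs`. The sketch below assumes you have timed one epoch yourself, for example the first epoch of a short trial run. The helper name `max_epochs_within_budget` and its 10% safety margin are hypothetical choices, not part of CellSeg3D.

```python
import math


def max_epochs_within_budget(seconds_per_epoch, budget_hours=12.0, safety_margin=0.9):
    """Hypothetical helper: largest number_of_epochs that fits a time budget.

    seconds_per_epoch is measured by you (e.g. from a trial run);
    safety_margin reserves part of the budget for setup, validation and saving.
    """
    if seconds_per_epoch <= 0:
        raise ValueError("seconds_per_epoch must be positive")
    usable_seconds = budget_hours * 3600 * safety_margin
    return math.floor(usable_seconds / seconds_per_epoch)


# Example: if one epoch takes about 15 minutes (900 s)
print(max_epochs_within_budget(900))
```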

## **4.1. Initialize the config**
---

In [7]:
# @title
import torch  # used to pick the training device

train_data_folder = Path(training_source)
results_path = Path(model_path)
# Create the results folder (and any missing parent folders)
results_path.mkdir(parents=True, exist_ok=True)
eval_image_folder = Path(eval_source)
eval_label_folder = Path(eval_target)

# Build the evaluation dataset only if validation is enabled
eval_dict = c.create_eval_dataset_dict(
        eval_image_folder,
        eval_label_folder,
    ) if do_validation else None

# WandB logging is enabled only if the package is installed
try:
  import wandb
  WANDB_INSTALLED = True
except ImportError:
  WANDB_INSTALLED = False

# Fall back to the CPU (much slower) if no GPU was allocated
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Default advanced parameters: only the basic settings are passed,
# so WNetTrainingWorkerConfig's own defaults are used for the rest
train_config = WNetTrainingWorkerConfig(
    device=device,
    max_epochs=number_of_epochs,
    learning_rate=2e-5,
    validation_interval=2,
    batch_size=4,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
) if use_default_advanced_parameters else WNetTrainingWorkerConfig(
    device=device,
    max_epochs=number_of_epochs,
    learning_rate=learning_rate,
    validation_interval=validation_frequency,
    batch_size=batch_size,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
    # advanced
    num_classes=num_classes,
    weight_decay=weight_decay,
    intensity_sigma=intensity_sigma,
    spatial_sigma=spatial_sigma,
    radius=ncuts_radius,
    reconstruction_loss=rec_loss,
    n_cuts_weight=n_cuts_weight,
    rec_loss_weight=rec_loss_weight,
)
wandb_config = WandBConfig(
    mode="disabled" if not WANDB_INSTALLED else "online",
    save_model_artifact=False,
)

No files found in /scratch/th3129/wormID/datasets/000541/sub-20190924-01


TypeError: object of type 'NoneType' has no len()
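
<font size = 4>The `No files found` message and the `TypeError` above appear when `training_source` contains no usable volumes. Before building the config, you can run a quick check like the one below to confirm that the folder exists and contains multipage TIF files. This minimal sketch assumes the volumes use the `.tif` or `.tiff` extension. `find_tif_volumes` is a hypothetical helper, not part of `napari_cellseg3d`.

```python
from pathlib import Path


def find_tif_volumes(folder):
    """Hypothetical helper: list .tif/.tiff files directly inside `folder`."""
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"Training folder does not exist: {folder}")
    # Match extensions case-insensitively (e.g. .TIF, .tiff)
    volumes = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in {".tif", ".tiff"}
    )
    if not volumes:
        raise FileNotFoundError(f"No .tif/.tiff volumes found in {folder}")
    return volumes
```

<font size = 4>For example, run `for vol in find_tif_volumes(training_source): print(vol.name)` before section 4.1 to list the volumes that training should pick up.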


## **4.2. Start training**
---

In [None]:
# @title
worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)
# worker.train() is consumed as an iterator; looping over it runs the training
for epoch_loss in worker.train():
  continue