# End-to-end fine-tuning of nnUNet with TotalSegmentator

## 0. Setting up.
- Set up all the environment variables needed by nnUNet (done in the following cells).
- Make sure the dataset is saved in the right format (`.nii.gz`) and with the right file names (see the [nnUNet documentation](https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md)).

In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import shutil
import subprocess


from nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets import move_plans_between_datasets
from totalsegmentator.config import setup_nnunet, setup_totalseg
from totalsegmentator.libs import download_pretrained_weights

In [2]:
# Run the TotalSegmentator setup helpers first; every variable assigned below
# is written afterwards, so the values in this cell are the ones in effect.
setup_nnunet()
setup_totalseg()

# --- nnUNet working folders -------------------------------------------------
# Each nnUNet folder variable points to a sub-folder of the same workspace,
# named exactly like the variable itself.
workspace = '/home/matteo.fusconi/TOTALSEGMENTATOR'
for folder_var in ('nnUNet_raw', 'nnUNet_preprocessed', 'nnUNet_results'):
    os.environ[folder_var] = f"{workspace}/{folder_var}"

# --- dataset folder names ---------------------------------------------------
pretrained_dataset = 'Dataset294_TotalSegmentator_part4_muscles_1559subj'
os.environ['TotalSegmentator'] = pretrained_dataset  # Totalsegmentator dummy dataset
os.environ['femur_left'] = "Dataset001_Femur_left"  # '7'
os.environ['femur_right'] = "Dataset002_Femur_right"  # '8'

# --- location of the pretrained TotalSegmentator model ----------------------
totalseg_home = os.path.join(os.path.expanduser("~"), '.totalsegmentator')
totalseg_model_dir = os.path.join(totalseg_home, 'nnunet', 'results',
                                  pretrained_dataset,
                                  'nnUNetTrainerNoMirroring__nnUNetPlans__3d_fullres')
os.environ['TOTALSEG_HOME_DIR'] = totalseg_home
os.environ['PATH_TO_CHECKPOINT'] = os.path.join(totalseg_model_dir, 'fold_0',
                                                "checkpoint_final.pth")
os.environ['TOTALSEG_MODEL_DIR'] = totalseg_model_dir

# --- sub-folder names for training images and labels ------------------------
os.environ['data_trainingImages'] = "imagesTr"
os.environ['data_trainingLabels'] = "labelsTr"

# Dataset folder names, in the order: left femur, right femur, TotalSegmentator.
classes = [os.environ.get(name_var)
           for name_var in ('femur_left', 'femur_right', 'TotalSegmentator')]

# Per-dataset paths to the training images and to the ground-truth labels.
raw_root = os.environ.get('nnUNet_raw')
images_subdir = os.environ.get('data_trainingImages')
labels_subdir = os.environ.get('data_trainingLabels')
raw_data = [os.path.join(raw_root, dataset, images_subdir) for dataset in classes]
gt_labels = [os.path.join(raw_root, dataset, labels_subdir) for dataset in classes]

# Fetch the pretrained weights for TotalSegmentator task 294.
task_id = 294
download_pretrained_weights(task_id)


## 1. Preprocessing
- Retrieve the plans from the TotalSegmentator model (they are stored in the `.totalsegmentator` folder).
- Preprocess your target dataset (this creates the dataset fingerprint and preprocesses the instances into the nnUNet preprocessed folder).
- **Move the plans** from the original TotalSegmentator dataset to the new one.

In [6]:
# Folder of the TotalSegmentator (dummy) dataset inside nnUNet_preprocessed.
totalseg_preprocessed_dir = os.path.join(os.environ.get('nnUNet_preprocessed'),
                                         os.environ.get('TotalSegmentator'))
if not os.path.exists(totalseg_preprocessed_dir):
    Path(totalseg_preprocessed_dir).mkdir(parents=True, exist_ok=True)

# Copy the pretrained model's plans.json into that folder as nnUNetPlans.json.
# The call is kept as the last expression so the notebook displays the
# destination path returned by shutil.copy.
pretrained_plans_file = os.path.join(os.environ.get('TOTALSEG_MODEL_DIR'), 'plans.json')
shutil.copy(src=pretrained_plans_file,
            dst=os.path.join(totalseg_preprocessed_dir, 'nnUNetPlans.json'))

'/home/matteo.fusconi/TOTALSEGMENTATOR/nnUNet_preprocessed/Dataset294_TotalSegmentator_part4_muscles_1559subj/dataset_fingerprint.json'

In [5]:
# Run the nnUNet fingerprint extraction, experiment planning and preprocessing
# command with dataset id "4" and the 3d_fullres configuration.
# check_call raises CalledProcessError if the command exits with a non-zero code.
plan_and_preprocess_cmd = ["nnUNetv2_plan_and_preprocess", "-d", "4", "-c", "3d_fullres"]
subprocess.check_call(plan_and_preprocess_cmd)


Fingerprint extraction...
Dataset003_Femur
Using <class 'nnunetv2.imageio.simpleitk_reader_writer.SimpleITKIO'> as reader/writer


100%|██████████| 40/40 [00:06<00:00,  5.86it/s]


Experiment planning...

############################
INFO: You are using the old nnU-Net default planner. We have updated our recommendations. Please consider using those instead! Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md
############################

Attempting to find 3d_lowres config. 
Current spacing: [2.         0.76445362 0.76445362]. 
Current patch size: (np.int64(112), np.int64(128), np.int64(128)). 
Current median shape: [248.         311.16504854 248.54368932]
Attempting to find 3d_lowres config. 
Current spacing: [2.         0.78738722 0.78738722]. 
Current patch size: (np.int64(128), np.int64(128), np.int64(112)). 
Current median shape: [248.         302.10198888 241.30455274]
Attempting to find 3d_lowres config. 
Current spacing: [2.         0.81100884 0.81100884]. 
Current patch size: (np.int64(128), np.int64(128), np.int64(112)). 
Current median shape: [248.         293.30290182 234.27626479]
Attempting to find 3d_lowr

100%|██████████| 40/40 [01:41<00:00,  2.53s/it]


'subprocess.check_call(["nnUNetv2_plan_and_preprocess", \n                       "-d", "2",\n                       "-c", "3d_fullres"])'

In [14]:
# Dataset ids and plans identifiers used to transfer the pretrained plans.
PRETRAINING_DATASET = 294
TARGET_DATASET = 4
PRETRAINING_PLANS_IDENTIFIER = "nnUNetPlans"
TARGET_PLANS_IDENTIFIER = "totseg_nnUNetPlans"

# Command-line flags in the order they are passed:
# source dataset, target dataset, source plans name, target plans name.
move_plans_flags = {
    "-s": PRETRAINING_DATASET,
    "-t": TARGET_DATASET,
    "-sp": PRETRAINING_PLANS_IDENTIFIER,
    "-tp": TARGET_PLANS_IDENTIFIER,
}

move_plans_cmd = ["nnUNetv2_move_plans_between_datasets"]
for flag, value in move_plans_flags.items():
    # Every argv entry must be a string, including the integer dataset ids.
    move_plans_cmd += [flag, str(value)]

subprocess.check_call(move_plans_cmd)


Dataset294_TotalSegmentator_part4_muscles_1559subj


0

## 2. Fine-tuning
- Manually rename the folder `nnUNetPlans_3d_fullres` to `totseg_nnUNetPlans_3d_fullres` (TODO: automate this step).
- Run the training.

In [None]:
# Training settings, all kept as strings because they are passed on the command line.
# NOTE(review): the preprocessing and move-plans cells of this notebook use
# dataset id 4, while this cell trains dataset "3" — confirm which id is intended.
TARGET_DATASET = "3"
CONFIG = "3d_fullres"
FOLD = "1"
TRAINER = "nnUNetTrainer_20epochs"

# Checkpoint of the pretrained model; raises KeyError if the variable is not set.
PATH_TO_CHECKPOINT = os.environ["PATH_TO_CHECKPOINT"]

# Positional arguments first (dataset, configuration, fold), then the options:
# pretrained weights, plans identifier and trainer class.
command = ["nnUNetv2_train", TARGET_DATASET, CONFIG, FOLD]
command += ["-pretrained_weights", PATH_TO_CHECKPOINT]
command += ["-p", "totseg_nnUNetPlans"]
command += ["-tr", TRAINER]

subprocess.check_call(command)