In [3]:
#requires you to pip install nnunetv2 and some others
#for install details see: https://github.com/MIC-DKFZ/nnUNet
import os,sys

import nnunetv2
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
import numpy as np
import pandas as pd
import SimpleITK as sitk
import torch
from tqdm import tqdm

#sys.path.append('path_to_dir_w_scripts') #not required if in the same dir
sys.path.append('/home/hvv/Documents/git_repo') #not required if in the same dir
from nnunet_utils.utils import np2sitk, set_env_nnunet, write_envlines_nnunet, assign_to_gpus
from nnunet_utils.preprocess import write_as_nnunet, nnunet_directory_structure, preprocess_data
from nnunet_utils.run import train_single_model, nnunet_train_shell

#root: the folder in which your nnunet data is stored
root = '/home/hvv/Documents/nnunet'
datano = '512' #this is an arbitrary number you can choose --> should not be the same as other studies
project_name = 'nameyourproject'
#both identifiers are derived from datano and project_name,
#so renaming the project in one place keeps everything in sync
task = 'Task{}_{}'.format(datano, project_name) #task name (used by the old nnunet version)
datasetID = 'Dataset{}_{}'.format(datano, project_name) #dataset name passed to preprocessing and training

In [None]:
#iterate over your dataset
#create a training images and label folder

#write your own script here loading an IMG and GT every iteration
#something like:

p_dir = 'where_you_store_your_nnunet_data_and_labels'

p_data = 'path_to_source_img_lbl' #subfolders with IDs and scan and gt files
#sorted: os.listdir returns entries in arbitrary order, sorting keeps reruns reproducible
for ID in sorted(os.listdir(p_data)):
    pid = os.path.join(p_data,ID)

    p_img = os.path.join(pid,'scan.nii.gz')
    p_gt = os.path.join(pid,'gt.nii.gz')

    #stray files or incomplete case folders would otherwise crash the loop halfway through
    #report them and move on, so the complete cases are still converted
    if not (os.path.isfile(p_img) and os.path.isfile(p_gt)):
        print('skipping {}: scan.nii.gz or gt.nii.gz not found'.format(ID))
        continue

    IMG = sitk.ReadImage(p_img)
    GT = sitk.ReadImage(p_gt)
    #this example should be applied to all your training images-labels
    #you can also use this to preprocess your test set
    #this is however not required
    write_as_nnunet(IMG, GT, p_dir, ID)
#IMG: sitk.Image with the CT/MR scan
#GT: sitk.Image with the corresponding ground truth segmentations
#p_dir: where the imagesTr and labelsTr should be stored
#ID: ID number (including dataset name) for identification of IMG-GT pairs

#sanity check to see if all images have labels
root_images = os.path.join(p_dir,'imagesTr')
root_gt = os.path.join(p_dir, 'labelsTr')
#NOTE(review): img_lbl_paircount is not imported in the first cell
#presumably it lives in nnunet_utils -- TODO confirm the module and add it to the imports
img_lbl_paircount(root_images, root_gt)

In [2]:
#this creates the nnunet directory structure inside the root folder
#root is defined in the first cell, so that cell has to be run before this one
#version=2 presumably selects the nnunetv2 folder layout -- TODO confirm in nnunet_utils.preprocess
nnunet_directory_structure(root,version=2)
#afterwards: make sure your imagesTr and labelsTr folders are pasted into the nnUnet_raw folder

In [None]:
#nnUnet requires its own preprocessing pass over the scans before training can start
#this may take a while, if it fails run again

#a list with one name per input channel --> important: should include MR or CT
input_channels = ['BL_MR_FLAIR']

preprocess_data(
    root,
    dataset_name=project_name,
    datasetID=datasetID, #or task name in old version
    datano=datano,
    modalities=input_channels,
)

In [None]:
#there are two options to instantiate training models
#1) one-by-one:
#train models consecutively for each fold --> run this manually 5 times (fold=0,1,2,3,4)
train_single_model(gpu=0, #each pc with a single gpu has number 0, selecting another gpu on a server is possible
                   datasetID=datasetID, #defined above
                   resolution='3d_fullres', #can select nnUnet config: ['2d','3d_fullres','3d_lowres', '3d_cascade_fullres']
                   fold=0, #start with the first fold (number 0)
                  )


#2) parallel across gpus:
#2a) Create mapping: which GPU does what
#Assign jobs to gpus: this is an equal distribution script
#it can be wise to first check gpu availability
#and then make your own distribution dictionary
#returns a dictionary with per entry:
# gpu_number:[job1, job2]
#where each job:
#(resolution, fold_number)

#these three were never defined, which made this cell fail with a NameError
#adjust the values to your own server and experiment
num_gpus = 1 #total number of GPUs available OR a list of available GPU numbers (torch.cuda.device_count() reports the total)
num_folds = 5 #number of folds to train (default=5)
resolutions = ['3d_fullres'] #list of resolutions, any from ['2d','3d_fullres','3d_lowres', '3d_cascade_fullres']

gpu_dct = assign_to_gpus(num_gpus, num_folds, resolutions)


#2b) Create shell script
#create a train_job.sh shell script to run multiple folds at the same time
#the shell script manages parallel computation across gpus
nnunet_train_shell(datasetID=datasetID, #defined above
                    root=root,#defined above
                    conda_env='/path/to/miniconda3/envs/nnunetv2', #path to your environment
                    gpu_res_fold_dct=gpu_dct, #is dictionary mapping resolutions, folds and gpus (see above)
                    version=2)

#2c) Run shell script
#Last thing: run the shell script on the server
#ssh to server, cd to nnunet folder then: bash train_job.sh
#to make sure the job keeps running on the server when you close your pc
#use tmux: https://tmuxcheatsheet.com/ and https://hamvocke.com/blog/a-quick-and-easy-guide-to-tmux/

In [13]:
#Inference: to predict segmentations using a trained model
#After your model is trained these scripts can be used on new cases
#there are two ways
#1) Run in the python script line by line
from nnunet_utils.infv2 import init_predictor, nnunetv2_get_props, nnunetv2_predict

#folder of the trained model configuration --> replace with the path to your own trained model
p_model = '/media/hvv/ec2480e5-6c18-468c-b971-5271432b386d/hvv/graph_age_data/mra_nnunet/train_data/MRA_vesselseg/nnUNet_trained_models/Dataset506_MRAvseg/nnUNetTrainer__nnUNetPlans__3d_fullres'
#this line was commented out, but `predictor` is passed to nnunetv2_predict in the inference loop
#without it every case in that loop fails with a NameError
predictor = init_predictor(p_model)


#2) Create a folder with all input images similar to the imagesTr (but now imagesTs)
#and run it in a batch at once


In [None]:

p_data = 'your_test_set_folder'

#IDs of the cases that raised an error, so they can be inspected and rerun afterwards
failed_ids = []

for ID in tqdm(os.listdir(p_data)):
    #was os.path.join(p_crisp,ID): p_crisp is not defined anywhere, the case folders live in p_data
    pid = os.path.join(p_data,ID)
    #input file
    file = os.path.join(pid,'scan.nii.gz')

    #output nifti segmentation and also probability output as npy
    p_vseg_out = os.path.join(pid,'vesselseg.nii.gz')
    p_npy_vseg = os.path.join(pid,'vesselseg')

    #sanity check to not run the same stuff twice
    if os.path.exists(p_vseg_out) and os.path.exists(p_npy_vseg+'.npy'):
        continue
    #running this for loop can take long
    #so a try-except to prevent stopping somewhere in the middle
    try:
        mra = sitk.ReadImage(file)
        props = nnunetv2_get_props(mra)
        #add a leading axis of size 1 in front of the image array before prediction
        mra_inp = np.expand_dims(sitk.GetArrayFromImage(mra),0)
        seg = nnunetv2_predict(mra_inp,props,predictor, return_probabilities=True)

        #seg[0] is written as the nifti segmentation, seg[1] is saved as npy (the probability output)
        sitk.WriteImage(np2sitk(seg[0],mra),p_vseg_out)

        np.save(p_npy_vseg,seg[1])
    except Exception as err:
        #a bare except would also swallow KeyboardInterrupt (the loop could not be stopped)
        #and would hide every failure; report the case and continue with the next one
        failed_ids.append(ID)
        tqdm.write('{} failed: {!r}'.format(ID, err))
        continue

if failed_ids:
    print('{} case(s) failed: {}'.format(len(failed_ids), failed_ids))