<a href="https://colab.research.google.com/github/cerr/pycerr-notebooks/blob/main/auto_register_segment_MR_Pancreas_OARs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

In this tutorial, we will demonstrate how to apply a pre-trained AI model [1] for deformable registration and segmentation of organs at risk (OARs) on longitudinal T2w-MRI scans using pyCERR.


## Requirements
* Python>=3.8
* Applying this model requires access to a GPU.  
  *On Colab*: `Runtime > Change runtime type > Select GPU`  
  
## I/O  

* **Inputs**: DICOM-format longitudinal T2w MRI of the pancreas and segmentation masks of the liver, bowels (small, large), and duo-stomach on the earlier (moving) scan.

* Pre-processing: The earlier scan must be rigidly registered to the later scan, and input segmentations of the OARs listed below must be available on the earlier scan.  

* **Outputs**:  
  [1] DICOM RTStruct-format segmentations of
  * Liver
  * Bowel_Lg
  * Bowel_Sm
  * Duostomach
  
  [2] Deformation vector field (DVF) that deforms the earlier (moving) scan to the later (baseline) scan.      
      
      
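  Input data should be organized as one directory per patient, with one sub-directory of DICOM files per time point. The earlier time point holds the rigidly registered scan and its RTSTRUCT.       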
    
```    
    Input dir
            |------Pat1  
                      |---Week1
                            |------REG_img1.dcm  
                                   REG_img2.dcm  
                                   ....  
                                   REG_RTSTRUCT  
                      |---Week2
                            |------img1.dcm  
                                   img2.dcm  
                                   ....  
                                   ....  
            |-----Pat2  
                     |---Week1
                            |------REG_img1.dcm  
                                   REG_img2.dcm  
                                   ....  
                                   REG_RTSTRUCT  
                      |---Week2
                            |------img1.dcm  
                                   img2.dcm  
                                   ....  
                                   ....
```
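Before running the pipeline, it can help to confirm that every patient directory follows this layout. The snippet below is a minimal sketch using only the standard library. The helper names `find_patient_timepoints` and `patients_missing_timepoints` are hypothetical and not part of pyCERR, and the check assumes that each time point is a sub-directory of the patient directory, as in the tree above.

```python
import os


def find_patient_timepoints(input_dir):
    """Map each patient directory name to its sorted time-point sub-directories."""
    layout = {}
    for patient in sorted(os.listdir(input_dir)):
        patient_dir = os.path.join(input_dir, patient)
        if not os.path.isdir(patient_dir):
            continue  # Skip stray files at the top level
        layout[patient] = sorted(
            name for name in os.listdir(patient_dir)
            if os.path.isdir(os.path.join(patient_dir, name))
        )
    return layout


def patients_missing_timepoints(layout, min_timepoints=2):
    """List patients with fewer time points than registration needs (assumed: 2)."""
    return [patient for patient, timepoints in layout.items()
            if len(timepoints) < min_timepoints]
```

Calling `patients_missing_timepoints(find_patient_timepoints(inputDicomPath))` would then list the patients to fix or exclude before the model is applied.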


## Installing the model and its dependencies

* Installation is performed using CERR's [***model installer***](https://github.com/cerr/model_installer).  

* A Conda archive containing dependencies is downloaded to the `conda-pack`   
  sub-directory of a configurable `scriptInstallDir`.  
  By default `condaEnvPath = '/content/MRI_Pancreas_Fullshot_AnatomicCtxShape/conda-pack'`
  
* The inference script is located at   
  `scriptPath = os.path.join(scriptInstallDir, 'model_wrapper', 'run_inference_first_to_last_nii.py')`
  
  
## Running the model

```python
!python {wrapperPath} {input_nii_directory} {output_nii_directory}
```
* Data transformations including converting between DICOM and NIfTI formats,  automatic extraction of patient outline, and resizing, are performed using [***pyCERR***](https://github.com/cerr/pyCERR) [2].
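
Outside a notebook cell, the same call can be assembled as a single shell command string. The snippet below is a minimal sketch rather than the installer's own method: `build_inference_command` is a hypothetical helper, and it assumes the packaged Conda environment is activated by sourcing its `bin/activate` script. Quoting each path with `shlex.quote` keeps the command valid when a directory name contains spaces.

```python
import shlex


def build_inference_command(activate_script, script_path, input_dir, output_dir):
    """Compose a bash command that activates the environment, then runs inference."""
    # Quote every argument so paths with spaces survive shell parsing
    run_parts = ["python", script_path, input_dir, output_dir]
    run_cmd = " ".join(shlex.quote(part) for part in run_parts)
    return f"source {shlex.quote(activate_script)} && {run_cmd}"
```

The resulting string can be passed to `subprocess.run(..., shell=True, executable="/bin/bash")`, since `source` is a bash built-in.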

## References

1. Jiang, J., Hong, J., Tringale, K., Reyngold, M., Crane, C., Tyagi, N., & Veeraraghavan, H. (2023). Progressively refined deep joint registration segmentation (ProRSeg) of gastrointestinal organs at risk: Application to MRI and cone-beam CT. Medical Physics, 50(8), 4758-4774.

2. Iyer, A., Locastro, E., Apte, A. P., Veeraraghavan, H., & Deasy, J. O. (2021). Portable framework to deploy deep learning segmentation models for medical images. bioRxiv, 2021-03.


## License

By downloading the software you are agreeing to the following terms and conditions as well as to the Terms of Use of CERR software.

**`THE SOFTWARE IS PROVIDED "AS IS" AND CERR DEVELOPMENT TEAM AND ITS COLLABORATORS DO NOT MAKE ANY WARRANTY, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, NOR DO THEY ASSUME ANY LIABILITY OR RESPONSIBILITY FOR THE USE OF THIS SOFTWARE.`**

`This software is for research purposes only and has not been approved for clinical use.`

`Software has not been reviewed or approved by the Food and Drug Administration, and is for non-clinical, IRB-approved Research Use Only. In no event shall data or images generated through the use of the Software be used in the provision of patient care.`
  
`YOU MAY NOT DISTRIBUTE COPIES of this software, or copies of software derived from this software, to others outside your organization without specific prior written permission from the CERR development team except where noted for specific software products.`

`All Technology and technical data delivered under this Agreement are subject to US export control laws and may be subject to export or import regulations in other countries. You agree to comply strictly with all such laws and regulations and acknowledge that you have the responsibility to obtain
such licenses to export, re-export, or import as may be required after delivery to you.`

**`You may publish papers and books using results produced using software provided that you reference the following`**:
  
  1. **AI model** :  https://doi.org/10.1002/mp.16527
  2. **Data processing** :  https://doi.org/10.1101/2021.03.17.435903


In [None]:
import os


# Input data

Define paths to input DICOM directory, desired output directory, and a session directory to store temporary files during model execution.

In [None]:
#I/O paths
inputDicomPath = 'your/path/here'
workDir = '/content/demo'
inputSubDirs = ['Masks']
outputDicomPath = os.path.join(workDir, 'AIoutput')
sessionPath = os.path.join(workDir, 'temp')

# Create the working, output, and session directories if they do not exist
for dirPath in (workDir, outputDicomPath, sessionPath):
    os.makedirs(dirPath, exist_ok=True)

## Download pre-trained model, [inference script](https://github.com/cerr/MRI_Pancreas_Fullshot_AnatomicCtxShape), and dependencies to ***scriptInstallDir***


In [None]:
os.chdir(workDir)
!git clone https://github.com/cerr/model_installer.git
os.chdir(os.path.join(workDir,'model_installer'))

modelOpt = '5'  # MRI_Pancreas_Fullshot_AnatomicCtxShape
pythonOpt = 'C' # Download packaged Conda environment

! source ./installer.sh -m {modelOpt} -d {workDir} -p {pythonOpt}

In [None]:
# Location of inference script
scriptInstallDir = os.path.join(workDir, 'model_installer', 'MRI_Pancreas_Fullshot_AnatomicCtxShape')
scriptPath = os.path.join(scriptInstallDir,
                         'model_wrapper',
                         'run_inference_first_to_last_nii.py')

# Location of activation script for Conda environment
activateScript = os.path.join(scriptInstallDir,'bin','activate')

## Install ***pyCERR***

pyCERR is used for data import/export and pre- and post-processing transformations needed for this model.

In [None]:
!pip install "pyCERR @ git+https://github.com/cerr/pyCERR.git@testing"

# Data processing functions

## Pre-processing  

### `processInputData`: Crop scan and input mask around patient outline and resize to 128 x 192 x 128 voxels

In [None]:
# Map output labels to structure names
structToLabelMap  = {"Liver": 1,
                     "Bowel_Lg": 2,
                     "Bowel_Sm": 3,
                     "Duostomach": 4}

In [None]:
import numpy as np

import cerr.plan_container as pc
from cerr.dataclasses import scan as cerrScn
from cerr.dataclasses import structure as cerrStr
from cerr.contour import rasterseg as rs
from cerr.utils import mask
from cerr.utils.image_proc import resizeScanAndMask


def processInputData(scanIdx, outlineStructName, structToLabelMap, planC):
    """Pre-process scan and mask for input to model"""

    #--------------------------------------------------
    #          Extract input label map
    #---------------------------------------------------
    # Avoid shadowing the built-in `str` in the comprehension
    structList = [struct.structureName for struct in planC.structure]

    if structToLabelMap is not None:

        segStructList = list(structToLabelMap.keys())
        labels = list(structToLabelMap.values())
        structNumV = np.zeros(len(labels),dtype=int)

        for numLabel in range(len(labels)):
            # Get structure indices in planC
            structName = segStructList[numLabel]
            matchIdxV = cerrStr.getMatchingIndex(structName,
                                                 structList,
                                                 matchCriteria='exact')
            # Ensure association with input scan
            assocScansV = [int(cerrScn.getScanNumFromUID(planC.structure[idx].assocScanUID, planC))
                           for idx in matchIdxV]
            try:
                validMatchIdx = matchIdxV[assocScansV.index(scanIdx)]
            except ValueError:
                # list.index raises ValueError when no match is on this scan
                raise ValueError(f"Missing structure {structName} on scan {scanIdx}")
            structNumV[numLabel] = validMatchIdx

        # Extract label map
        maskList = cerrStr.getMaskList(structNumV,
                                      planC,
                                      labelDict=structToLabelMap)
        mask4M = np.array(maskList)
        mask4M = np.moveaxis(mask4M, 0, -1)# nRows x nCols x nSlc x nlabels
    else:
        mask4M = None

    #--------------------------------------------------
    #          Process input scan
    #---------------------------------------------------
    modality = 'MR'
    scan3M = planC.scan[scanIdx].getScanArray()
    scanSizeV = np.shape(scan3M)

    # 1. Crop to  patient outline

    ## Extract outline
    outlineIdx = structList.index(outlineStructName) \
                if outlineStructName in structList else None

    if outlineIdx is None:
        # Generate outline mask
        threshold = 0.03
        outline3M = mask.getPatientOutline(scan3M,
                                           threshold,
                                           normFlag=True)

        planC = pc.importStructureMask(outline3M,
                                       scanIdx,
                                       outlineStructName,
                                       planC,
                                       None)
    else:
        # Load outline mask
        outline3M = rs.getStrMask(outlineIdx, planC)

    ## Crop scan and mask to pt outline
    cropMask4M = None
    minr, maxr, minc, maxc, mins, maxs, _ = mask.computeBoundingBox(outline3M)
    cropScan3M = scan3M[minr:maxr+1, minc:maxc+1, mins:maxs+1]
    cropScanSizeV = np.shape(cropScan3M)
    if mask4M is not None:
        cropMask4M = mask4M[minr:maxr+1, minc:maxc+1, mins:maxs+1, :]

    ## Calc. cropped scan grid
    gridS = planC.scan[scanIdx].getScanXYZVals()
    cropGridS = (gridS[0][minc:maxc+1],
                 gridS[1][minr:maxr+1],
                 gridS[2][mins:maxs+1])

    # 2. Resize scan
    ## Resize scan in-plane (slice count unchanged)
    outputImgSizeV = [128, 192, cropScanSizeV[2]]
    method = 'bicubic'
    procScan3M, procMask4M, resizeGridS = resizeScanAndMask(cropScan3M,
                                          cropMask4M,
                                          cropGridS,
                                          outputImgSizeV,
                                          method)
    ## Pad scan along slices
    resizeScanShape = procScan3M.shape
    outputImgSizeV = [resizeScanShape[0], resizeScanShape[1], 128]
    method = 'padslices'
    procPadScan3M, procPadMask4M, padGridS = resizeScanAndMask(procScan3M, \
                                                               procMask4M,\
                                                               resizeGridS, \
                                                               outputImgSizeV,\
                                                               method)

    #--------------------------------------------------
    #    Import processed scan & mask to planC
    #---------------------------------------------------
    ## Import scan
    planC = pc.importScanArray(procPadScan3M,
                                 padGridS[0],
                                 padGridS[1],
                                 padGridS[2],
                                 modality,
                                 scanIdx,
                                 planC)
    processedScanIdx = len(planC.scan) - 1

    ## Import mask
    processedStrIdxV = []
    if procPadMask4M is not None:
        for structIndex in range(len(segStructList)):
            structName = segStructList[structIndex]
            procPadMask3M = procPadMask4M[:, :, :, structIndex]
            planC = pc.importStructureMask(procPadMask3M,
                                           processedScanIdx,
                                           structName,
                                           planC,
                                           None)
            processedStrIdxV.append(len(planC.structure) - 1 )
    else:
        processedStrIdxV = None

    return processedScanIdx, processedStrIdxV, padGridS
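
The crop and slice-padding steps in `processInputData` can be illustrated with plain NumPy. This is a minimal sketch under stated assumptions, not pyCERR's implementation: the helper names are hypothetical, the in-plane bicubic resize is omitted, and the padding is assumed to be symmetric zero-padding, which may differ from how the `'padslices'` method distributes slices.

```python
import numpy as np


def bounding_box(mask3d):
    """Return inclusive (minr, maxr, minc, maxc, mins, maxs) of the non-zero voxels."""
    rows, cols, slcs = np.nonzero(mask3d)
    return (rows.min(), rows.max(), cols.min(), cols.max(),
            slcs.min(), slcs.max())


def crop_to_box(vol3d, box):
    """Crop a volume to an inclusive bounding box."""
    minr, maxr, minc, maxc, mins, maxs = box
    return vol3d[minr:maxr + 1, minc:maxc + 1, mins:maxs + 1]


def pad_slices(vol3d, num_slices):
    """Zero-pad along the slice axis (assumed symmetric); also return the leading pad."""
    extra = num_slices - vol3d.shape[2]
    if extra < 0:
        raise ValueError("Volume has more slices than the requested size")
    before = extra // 2
    padded = np.pad(vol3d, ((0, 0), (0, 0), (before, extra - before)),
                    mode="constant")
    return padded, before


def unpad_slices(vol3d, before, orig_slices):
    """Invert pad_slices by keeping only the original slice range."""
    return vol3d[:, :, before:before + orig_slices]
```

Keeping the bounding box and the leading pad width is what lets the post-processing step place the model output back on the original scan grid.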


## Post-processing

### `postProcAndImportSeg`: Read label maps, undo pre-processing transformations, and retain only the largest connected component of the resulting mask.
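
To make the relationship between a label map and per-structure masks concrete, the sketch below converts between the two using a dictionary shaped like `structToLabelMap`. It assumes non-overlapping structures and integer labels, and the helper names are hypothetical. The layout of the model's NIfTI output and pyCERR's own conversion may differ.

```python
import numpy as np


def label_map_to_masks(label_map, struct_to_label):
    """Split an integer label map into one boolean mask per structure name."""
    return {name: label_map == label
            for name, label in struct_to_label.items()}


def masks_to_label_map(masks, struct_to_label):
    """Merge boolean masks into a label map; later structures overwrite overlaps."""
    first_mask = next(iter(masks.values()))
    label_map = np.zeros(first_mask.shape, dtype=np.uint8)
    for name, label in struct_to_label.items():
        label_map[masks[name]] = label
    return label_map
```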

In [None]:
import glob
import SimpleITK as sitk

def postProcAndImportSeg(modOutputPath, baseScanIdx, outlineStructName,
                         structToLabelMap, inputGridS, planC):
    """ Import auto-segmentations to planC"""

    #--------------------------------------------------
    #              Read AI-generated mask
    #---------------------------------------------------
    niiGlob = glob.glob(os.path.join(modOutputPath, '*.nii.gz'))
    print('Importing ' + niiGlob[0] + '...')
    outputMask = sitk.ReadImage(niiGlob[0])
    outputMask3M = sitk.GetArrayFromImage(outputMask)
    numStrOrig = len(planC.structure)
    numAIStrs = len(structToLabelMap)

    #--------------------------------------------------
    #      Undo pre-processing transformations
    #---------------------------------------------------
    ## Extract extents of patient outline
    structList = [struct.structureName for struct in planC.structure]
    outlineIdx = structList.index(outlineStructName)
    outline3M = rs.getStrMask(outlineIdx, planC)
    minr, maxr, minc, maxc, mins, maxs, _ = mask.computeBoundingBox(outline3M)

    nSlc = maxs-mins+1
    resizedDimsV = [128, 192, nSlc]

    ## Undo padding
    outputScan3M = None
    method = 'unpadslices'
    _, unPadMask4M, unPadGridS = resizeScanAndMask(outputScan3M,
                                                   outputMask3M,
                                                   inputGridS,
                                                   resizedDimsV,
                                                   method)
    ## Undo resizing
    outputImgSizeV = [maxr-minr+1, maxc-minc+1, nSlc]
    method = 'bicubic'
    _, resizeMask4M, resizeGridS = resizeScanAndMask(outputScan3M,
                                                     unPadMask4M,
                                                     unPadGridS,
                                                     outputImgSizeV,
                                                     method)

    ## Undo cropping
    baseImgSizeV = list(planC.scan[baseScanIdx].getScanSize())
    fullMask4M = np.zeros(baseImgSizeV + [numAIStrs])
    fullMask4M[minr:maxr+1, minc:maxc+1, mins:maxs+1, :] = resizeMask4M

    #--------------------------------------------------
    #             Import results to planC
    #---------------------------------------------------
    numComponents = 1
    structNames = list(structToLabelMap.keys())
    for numLabel in range(numAIStrs):
        binMask = fullMask4M[:, :, :, numLabel]

        structName = 'AI_' + structNames[numLabel]
        planC = cerrStr.importStructureMask(binMask,
                                            baseScanIdx,
                                            structName,
                                            planC,
                                            None)
        # Post-process and replace input structure in planC
        structNumV = len(planC.structure) - 1
        importMask3M, planC = cerrStr.getLargestConnComps(structNumV,
                                                          numComponents,
                                                          planC,
                                                          saveFlag=True,
                                                          replaceFlag=True,
                                                          procSructName=structName)

    return planC
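
The "largest connected component" step can be sketched without pyCERR as a breadth-first search over the mask. This is an illustrative version under stated assumptions: it uses 6-connectivity (face neighbours only), and `largest_connected_component` is a hypothetical helper, so `getLargestConnComps` may use a different connectivity or algorithm. For large volumes, a labelling routine such as `scipy.ndimage.label` is the usual choice.

```python
from collections import deque

import numpy as np


def largest_connected_component(mask3d):
    """Return a boolean mask holding only the largest 6-connected component."""
    mask3d = np.asarray(mask3d, dtype=bool)
    visited = np.zeros_like(mask3d)
    offsets = [(1, 0, 0), (-1, 0, 0), (0, 1, 0),
               (0, -1, 0), (0, 0, 1), (0, 0, -1)]
    best = []
    for seed in zip(*np.nonzero(mask3d)):
        if visited[seed]:
            continue
        # Flood-fill one component starting from this unvisited voxel
        visited[seed] = True
        queue = deque([seed])
        component = []
        while queue:
            voxel = queue.popleft()
            component.append(voxel)
            for offset in offsets:
                nb = tuple(v + o for v, o in zip(voxel, offset))
                inside = all(0 <= n < s for n, s in zip(nb, mask3d.shape))
                if inside and mask3d[nb] and not visited[nb]:
                    visited[nb] = True
                    queue.append(nb)
        if len(component) > len(best):
            best = component
    out = np.zeros_like(mask3d)
    if best:
        out[tuple(np.array(best).T)] = True
    return out
```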

# Segment OARs

## Apply AI model  to all longitudinal MR datasets  

### located in ***inputDicomPath*** and store auto-segmentation results to ***outputDicomPath***


In [None]:
import subprocess
import numpy as np

from cerr import plan_container as pc
from cerr.dataclasses import scan as cerrScn
from cerr.dcm_export import rtstruct_iod
from cerr.utils.ai_pipeline import createSessionDir, getScanNumFromIdentifier

fileList = os.listdir(inputDicomPath)
numFiles = len(fileList)
scanNum = 0
modality = 'MR'
outlineStructName = 'patient_outline'

# Loop over patient DICOM directories
for iFile in range(numFiles):

    # Create session dir to store temporary data
    modInputPath, modOutputPath = createSessionDir(sessionPath,
                                                   inputDicomPath,
                                                   inputSubDirs)

    # Import DICOM scan to planC
    inputFile = fileList[iFile]
    dcmDir = os.path.join(inputDicomPath, inputFile)
    planC = pc.loadDcmDir(dcmDir)
    numExistingStructs = len(planC.structure)
    #--------------------
    # Pre-process data
    #---------------------

    #1. Base scan
    identifier = {"seriesDate": "last"}
    baseScanIdx = int(getScanNumFromIdentifier(identifier, planC)[0])
    exportBaseLabelMap = None
    procBaseScanIdx, __, procBaseGridS = processInputData(baseScanIdx,
                                                          outlineStructName,
                                                          exportBaseLabelMap,
                                                          planC)

    #2. Moving scan
    identifier = {"seriesDate": "first"}
    movScanIdx = int(getScanNumFromIdentifier(identifier, planC)[0])
    exportMovLabelMap = structToLabelMap
    procMovScanIdx, procMovStrIdxV, procMovGridS = \
                                    processInputData(movScanIdx,
                                                     outlineStructName,
                                                     exportMovLabelMap,
                                                     planC)

    # Export processed inputs to NIfTI
    baseScanFile = os.path.join(modInputPath,
                            f"{inputFile}_MR SCAN_last_scan_3D.nii.gz")
    movScanFile = os.path.join(modInputPath,
                           f"{inputFile}_MR SCAN_first_scan_3D.nii.gz")
    movMaskFile = os.path.join(modInputPath, inputSubDirs[0], \
                           f"{inputFile}_MR SCAN_first_4D.nii.gz")
    planC.scan[procBaseScanIdx].saveNii(baseScanFile)
    planC.scan[procMovScanIdx].saveNii(movScanFile)
    pc.saveNiiStructure(movMaskFile,
                        exportMovLabelMap,
                        planC,
                        strNumV=procMovStrIdxV,
                        dim=4)

    # Apply pretrained AI model
    cmd = f"source {activateScript} && python {scriptPath} " \
          f"{modInputPath} {modOutputPath}"
    subprocess.run(cmd, capture_output=False, shell=True,
                   executable="/bin/bash")


    #Import segmentations to planC
    planC = postProcAndImportSeg(modOutputPath,
                                 baseScanIdx,
                                 outlineStructName,
                                 structToLabelMap,
                                 procMovGridS,
                                 planC)

    newNumStructs = len(planC.structure)


    # Export segmentations to DICOM
    structFileName = fileList[iFile] + '_AI_seg.dcm'
    structFilePath = os.path.join(outputDicomPath, structFileName)
    structNumV = np.arange(numExistingStructs+1, newNumStructs)
    indOrigV = np.array([cerrScn.getScanNumFromUID(\
                         planC.structure[structNum].assocScanUID, planC)\
                         for structNum in structNumV], dtype=int)
    structsToExportV = structNumV[indOrigV == baseScanIdx]
    seriesDescription = "AI Generated"
    exportOpts = {'seriesDescription': seriesDescription}
    rtstruct_iod.create(structsToExportV,
                        structFilePath,
                        planC,
                        exportOpts)

# Display results

## Overlay AI segmentations on scan for visualization using ***Matplotlib***

Note: This example displays the last segmented dataset by default.  
Load the appropriate pyCERR archive to `planC` to view results for desired dataset.

In [None]:
from cerr.viewer import showMplNb

dispStructsV = list(structsToExportV[1:])

showMplNb(planC=planC, scanNum=baseScanIdx,
          structNums=dispStructsV,
          windowCenter=41100, windowWidth=80300)
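
The `windowCenter` and `windowWidth` arguments describe the range of intensities mapped to the display. The sketch below uses the common simplified convention in which the window spans half the width on either side of the centre. That `showMplNb` applies exactly this convention is an assumption, and `window_to_range` is a hypothetical helper.

```python
def window_to_range(center, width):
    """Convert a display window (center, width) to (lower, upper) intensity bounds."""
    half_width = width / 2.0
    return center - half_width, center + half_width


# Display window used in the call above
lower, upper = window_to_range(41100, 80300)
print(lower, upper)  # 950.0 81250.0
```

Under this convention, voxels below `lower` would render as black and voxels above `upper` as white, so the values should be adjusted to the intensity range of the scan being viewed.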