<a href="https://colab.research.google.com/github/pycerr-notebooks/blob/main/autosegment_CT_Lung_OARs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

In this tutorial, we demonstrate how to apply a pre-trained AI model to segment the organs at risk (OARs) on a lung CT scan using pyCERR.

## AI model
* The segmentation model was trained and validated on CT scans used for RT planning. It does not work optimally on diagnostic CTs.
* The trained model is distributed with its Python libraries and other dependencies as a conda package.
* The model requires access to a GPU.

### Running the model

* Conda package location (`condaEnvPath`): `'/home/jupyter/AImodels/CTLungOARincrMRRN/'`
* Inference script location: `os.path.join(condaEnvPath,'CT_LungOAR_incrMRRN/model_wrapper/run_inference_nii.py')`

```shell
python run_inference_nii.py <input_nii_directory> <output_nii_directory>
```
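
From a notebook, this command has to run inside the packaged environment. The following is a minimal sketch of assembling such a shell command, assuming the unpacked package provides a `bin/activate` script under `condaEnvPath`. `buildInferenceCommand` is a hypothetical helper written for illustration, not part of pyCERR, and the input and output directories are placeholders.

```python
import os
import shlex

def buildInferenceCommand(condaEnvPath, inputNiiDir, outputNiiDir):
    """Return a bash command that activates the packaged env and runs inference."""
    # Assumed layout: <condaEnvPath>/bin/activate and the wrapper script below
    activateScript = os.path.join(condaEnvPath, 'bin', 'activate')
    wrapperPath = os.path.join(condaEnvPath, 'CT_LungOAR_incrMRRN',
                               'model_wrapper', 'run_inference_nii.py')
    # shlex.quote guards against spaces or shell metacharacters in paths
    parts = [shlex.quote(p) for p in
             (activateScript, wrapperPath, inputNiiDir, outputNiiDir)]
    return 'source {} && python {} {} {}'.format(*parts)

# Placeholder directories, for illustration only
cmd = buildInferenceCommand('/home/jupyter/AImodels/CTLungOARincrMRRN/',
                            '/tmp/nii in', '/tmp/nii_out')
print(cmd)
```

Because `source` is a bash builtin, a string like this would be run with `subprocess.run(cmd, shell=True, executable='/bin/bash')`.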

# Install pyCERR

pyCERR is used for data import/export and transformation.

In [0]:
!pip install git+https://github.com/cerr/pyCERR.git

# Download pretrained segmentation model



In [1]:
# Location of LungOAR model
modelDir ='/content/pretrainedModel'

# Download packaged environment for the AI model
boxLink = 'https://mskcc.box.com/shared/static/xph6atg73tuxmo26ndrajxcm02h85elk.gz'
saveFileName = 'ct_lung_oar.tar.gz'

!mkdir -p {modelDir}
!wget {boxLink} -O {saveFileName}
!tar xf {saveFileName} -C {modelDir}
!rm {saveFileName}

# Functions for data pre- and post-processing

## Identify the input scan and resize slices to 512x512

In [2]:
import numpy as np
import cerr
from cerr.dataclasses import structure
from cerr.contour import rasterseg as rs
from cerr.utils import identifyScan, imageProc

def processInputData(planC):

  # Identify scan index in planC
  scanIdS = {"imageType": "CT SCAN"}
  matchScanV = identifyScan.getScanNumForIdentifier(scanIdS, planC, False)

  # Extract scan
  scanNum = matchScanV[0]
  scan3M = planC.scan[scanNum].getScanArray()
  mask3M = np.empty((0, 0, 0, 0))

  # Resize each slice to 512x512 by padding or cropping (no mask is passed)
  inputImgSizeV = np.shape(scan3M)
  gridS = planC.scan[scanNum].getScanXYZVals()
  outputImgSizeV = [512, 512, inputImgSizeV[2]]
  method = 'padorcrop3d'
  procScan3M, __, resizeGridS = imageProc.resizeScanAndMask(scan3M, mask3M, gridS, outputImgSizeV, method)

  return procScan3M, resizeGridS
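
To make the resizing step concrete, here is a minimal sketch of centered pad-or-crop on a single 2D slice, using plain Python lists. It assumes symmetric padding and cropping about the image center. Whether `padorcrop3d` uses the same centering and fill value is an assumption, and `centerSlices` and `padOrCrop2d` are hypothetical helpers, not pyCERR functions.

```python
def centerSlices(inLen, outLen):
    """Return (src, dst) slices that center an axis of length inLen in outLen."""
    if inLen >= outLen:
        start = (inLen - outLen) // 2      # crop: drop voxels on both sides
        return slice(start, start + outLen), slice(0, outLen)
    start = (outLen - inLen) // 2          # pad: leave a margin on both sides
    return slice(0, inLen), slice(start, start + inLen)

def padOrCrop2d(image, outRows, outCols, fill=0):
    """Pad or crop a 2D list-of-lists to outRows x outCols about its center."""
    rowSrc, rowDst = centerSlices(len(image), outRows)
    colSrc, colDst = centerSlices(len(image[0]), outCols)
    out = [[fill] * outCols for _ in range(outRows)]
    # Copy the overlapping region row by row
    for rIn, rOut in zip(range(rowSrc.start, rowSrc.stop),
                         range(rowDst.start, rowDst.stop)):
        out[rOut][colDst] = image[rIn][colSrc]
    return out

small = [[1, 2], [3, 4]]
padded = padOrCrop2d(small, 4, 4)    # 2x2 padded to 4x4
cropped = padOrCrop2d(small, 1, 2)   # 2x2 cropped to 1x2
print(padded)
print(cropped)
```

The same index arithmetic extends to the row and column axes of a 3D scan, with the slice axis left unchanged as in `outputImgSizeV` above.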

## Import and refine AI segmentations

In [3]:
# Map model output labels to structure names (PBT: proximal bronchial tree)

strToLabelMap = {1:"Lung_Left", 2:"Lung_Right", 3:"Heart", 4:"Esophagus", \
                 5:"Cord", 6:"PBT"}
numLabel = len(strToLabelMap)
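
As a quick sanity check on a label map, one could count the voxels assigned to each structure before importing it. The sketch below works on a flat sequence of integer labels and assumes that label 0 denotes background, which the notebook does not state. `countVoxelsPerStructure` and the sample labels are hypothetical.

```python
from collections import Counter

labelToName = {1: "Lung_Left", 2: "Lung_Right", 3: "Heart",
               4: "Esophagus", 5: "Cord", 6: "PBT"}

def countVoxelsPerStructure(labelValues, labelToName, background=0):
    """Count voxels per named structure in a flat sequence of integer labels."""
    # Assumption: the background label (0 here) carries no structure
    counts = Counter(v for v in labelValues if v != background)
    unknown = sorted(set(counts) - set(labelToName))
    if unknown:
        raise ValueError(f"Unexpected label values: {unknown}")
    return {name: counts.get(label, 0) for label, name in labelToName.items()}

# Made-up labels, for illustration only
flatLabels = [0, 1, 1, 2, 3, 3, 3, 0, 6]
voxelCounts = countVoxelsPerStructure(flatLabels, labelToName)
print(voxelCounts)
```

A structure with a count of zero would indicate that the model did not segment it on this scan.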

In [4]:
# Import label map to pyCERR
import glob
import os
import numpy as np
from cerr import plan_container as pc

def postProcAndImportSeg(outputDir,procScanNum,scanNum,planC):
  niiGlob = glob.glob(os.path.join(outputDir,'*.nii.gz'))

  print('Importing ' + niiGlob[0]+'...')
  numStrOrig = len(planC.structure)
  planC = pc.load_nii_structure(niiGlob[0], procScanNum, planC, \
                              labels_dict = strToLabelMap)
  cpyStrNumV = np.arange(numStrOrig,len(planC.structure))
  numComponents = 1
  for label in range(numLabel):
    # Copy to original scan
    planC = structure.copyToScan(cpyStrNumV[label], scanNum, planC)
    origStr = len(planC.structure)-1
    mask3M = rs.getStrMask(origStr,planC)
    # Post-process
    procMask3M = imageProc.getLargestConnComps(mask3M,numComponents)
    strName = strToLabelMap[label+1]  # Label values are 1-based
    planC = pc.import_structure_mask(procMask3M, scanNum, strName, [], planC)
    # Delete original
    del planC.structure[origStr]

  return planC
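
The post-processing step keeps only the largest connected component of each mask, which discards small isolated false positives. The sketch below illustrates the idea with a breadth-first flood fill over voxel coordinates. It assumes 6-connectivity (voxels sharing a face), which may differ from the connectivity used by `imageProc.getLargestConnComps`, and `largestComponents` is a hypothetical helper, not the pyCERR implementation.

```python
from collections import deque

# Assumed 6-connectivity: voxels are neighbors when they share a face
FACE_NEIGHBORS = [(1, 0, 0), (-1, 0, 0), (0, 1, 0),
                  (0, -1, 0), (0, 0, 1), (0, 0, -1)]

def largestComponents(voxels, numComponents=1):
    """Return the numComponents largest connected sets of (i, j, k) voxels."""
    remaining = set(voxels)
    components = []
    while remaining:
        seed = remaining.pop()
        component = {seed}
        queue = deque([seed])
        while queue:                      # breadth-first flood fill from the seed
            i, j, k = queue.popleft()
            for di, dj, dk in FACE_NEIGHBORS:
                neighbor = (i + di, j + dj, k + dk)
                if neighbor in remaining:
                    remaining.remove(neighbor)
                    component.add(neighbor)
                    queue.append(neighbor)
        components.append(component)
    components.sort(key=len, reverse=True)
    return components[:numComponents]

organ = {(0, 0, 0), (0, 0, 1), (0, 1, 1)}   # three face-connected voxels
speckle = {(5, 5, 5)}                        # isolated false positive
kept = largestComponents(organ | speckle)
print(kept)
```

With `numComponents = 1`, as in the function above, only the main body of each organ survives.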

# Segment OARs

## Define I/O paths

Specify paths to the DICOM input data, desired output directory, and temporary (session) directory used to store intermediate results.

In [5]:
import os
from cerr import plan_container as pc

# Paths to input data and conda env with pre-trained models
inputDicomPath = '/content/sampleData/'  # Replace with path to dataset
outputDicomPath = '/content/AIoutput/'
sessionPath = '/content/temp'
condaEnvPath = '/content/pretrainedModel/'

# Create the output directory (and any missing parents) if needed
os.makedirs(outputDicomPath, exist_ok=True)

# Paths to the inference script and activation script in the packaged conda environment
wrapperPath = os.path.join(condaEnvPath,'CT_LungOAR_incrMRRN', \
                          'model_wrapper','run_inference_nii.py')
activateScript = os.path.join(condaEnvPath,'bin','activate')

## Run AI model

In [6]:
%%capture
import subprocess
import numpy as np
import cerr
from cerr import plan_container as pc
from cerr.ai import createSessionDir as cdir
from cerr.dcm_export import rtstruct_iod

# Loop over DICOM directories (one sub-directory per scan)
fileList = os.listdir(inputDicomPath)
numFiles = len(fileList)
modality = 'CT'
scanNum = 0

for iFile in range(numFiles):

    dcmDir = os.path.join(inputDicomPath,fileList[iFile])

    # Create session dir to store temporary data
    modInputPath, modOutputPath = cdir.createSessionDir(sessionPath, inputDicomPath)

    # Import DICOM to planC
    planC = pc.load_dcm_dir(dcmDir)

    # Pre-process data
    procScan3M, resizeGridS = processInputData(planC)
    planC = pc.import_scan_array(procScan3M, resizeGridS[0], \
            resizeGridS[1], resizeGridS[2], modality, scanNum, planC)
    procScanNum = len(planC.scan) - 1

    # Export inputs to NIfTI
    scanFilename = os.path.join(modInputPath, f"{fileList[iFile]}_scan_3D.nii.gz")
    planC.scan[procScanNum].save_nii(scanFilename)

    # Apply model
    subprocess.run(f"source {activateScript} && python {wrapperPath} \
                  {modInputPath} {modOutputPath}", \
                  capture_output=False,shell=True,executable="/bin/bash")

    # Import results to planC
    planC = postProcAndImportSeg(modOutputPath,procScanNum,scanNum,planC)

    # Export segmentations to DICOM
    structFileName = fileList[iFile]+'_AI_seg.dcm'
    structFilePath = os.path.join(outputDicomPath,structFileName)
    structNumV = list(np.arange(len(planC.structure)-numLabel,\
                                len(planC.structure)))
    seriesDescription = "AI Generated"
    rtstruct_iod.create(structNumV,structFilePath,planC,seriesDescription)
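
The loop above treats every entry returned by `os.listdir` as a DICOM directory. If the input folder may also hold stray files or hidden folders, a small filter could be applied first. This is a sketch under the assumption that each scan sits in its own visible sub-directory. `listScanDirs` is a hypothetical helper, and the temporary folder contents are made up for the demonstration.

```python
import os
import tempfile

def listScanDirs(rootDir):
    """Return sorted names of visible sub-directories of rootDir."""
    return sorted(
        name for name in os.listdir(rootDir)
        if os.path.isdir(os.path.join(rootDir, name))
        and not name.startswith('.')      # skip hidden folders
    )

# Demonstration on a throwaway folder
with tempfile.TemporaryDirectory() as rootDir:
    os.mkdir(os.path.join(rootDir, 'patient01'))
    os.mkdir(os.path.join(rootDir, '.ipynb_checkpoints'))
    open(os.path.join(rootDir, 'notes.txt'), 'w').close()
    scanDirs = listScanDirs(rootDir)

print(scanDirs)
```

Replacing `os.listdir(inputDicomPath)` with such a filtered list would keep the loop from attempting to import non-DICOM entries.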


# Display results

## Display using matplotlib

In [7]:
from cerr.viewer import showMplNb

showMplNb(scanNum, structNumV, planC,\
          windowCenter=-400, windowWidth=2000)
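
The `windowCenter` and `windowWidth` arguments select the range of CT intensities (in Hounsfield units) mapped to the display grayscale. The sketch below uses the common simplified linear windowing formula. Whether `showMplNb` applies exactly this mapping is an assumption, and `windowBounds` and `applyWindow` are hypothetical helpers.

```python
def windowBounds(windowCenter, windowWidth):
    """Return the (lower, upper) intensity limits of a display window."""
    lower = windowCenter - windowWidth / 2
    upper = windowCenter + windowWidth / 2
    return lower, upper

def applyWindow(hu, windowCenter, windowWidth):
    """Map an intensity to [0, 1], clipping values outside the window."""
    lower, upper = windowBounds(windowCenter, windowWidth)
    clipped = min(max(hu, lower), upper)
    return (clipped - lower) / (upper - lower)

bounds = windowBounds(-400, 2000)
print(bounds)
print(applyWindow(-1000, -400, 2000))   # air
print(applyWindow(3000, -400, 2000))    # above the window, clipped
```

Under this formula, the settings used above display intensities between -1400 and 600 HU.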
