<a href="https://colab.research.google.com/github/cerr/pycerr-notebooks/blob/main/autosegment_CT_Heart_OARs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

In this tutorial, we will demonstrate how to apply a pre-trained AI model to segment the Heart sub-structures on a lung CT scan using pyCERR.

## AI model
* The segmentation model was trained and validated on CT scans used for RT planning. It does not work optimally on diagnostic CTs or scans in positions other than Head First Supine.
* The trained model is distributed along with python libraries and other dependencies via a conda package.
* The model requires access to a GPU.

### Running the model

* Conda package location: `heartCondaEnvDir` (source its `bin/activate` script before running the inference script)
* Inference script location: `wrapperPath`

```python
!python {wrapperPath} {input_nii_directory} {output_nii_directory}
```

# Install pyCERR

pyCERR is used for pre- and post-processing of the DICOM data.

In [None]:
# Install pyCERR from the 'testing' branch of its GitHub repository (not version-pinned).
!pip install git+https://github.com/cerr/pyCERR.git@testing

# Download pretrained segmentation model



### Download the pre-packaged Anaconda environment

In [None]:
# Location where the Heart segmentation conda environment is unpacked
heartCondaEnvDir ='/content/pretrainedHeartModel'

# Download packaged environment for the Heart segmentation model
boxLink = 'https://mskcc.box.com/shared/static/o524xsg5s91q0frwka10fvsf82395sgg.gz'
saveFileName = 'ct_heart_oar.tar.gz'

# Fetch the archive, unpack it into heartCondaEnvDir, then delete the archive.
!mkdir -p {heartCondaEnvDir}
!wget {boxLink} -O {saveFileName}
!tar xf {saveFileName} -C {heartCondaEnvDir}
!rm {saveFileName}

# Path to conda environment activate script; it is sourced before the
# inference wrapper is launched in the segmentation loop.
heartEnvActivateScript = heartCondaEnvDir + '/bin/activate'

### Download the inference script and model weights

In [None]:
# Clone the repository that holds the inference wrapper (runSegmentation.py).
# NOTE(review): re-running this cell will likely fail at `git clone`, because
# wrapperInstallDir already exists from the previous run.
import os
wrapperInstallDir = '/content/CT_cardiac_structures_deeplab'
!git clone https://github.com/cerr/CT_cardiac_structures_deeplab.git {wrapperInstallDir}
# Script that the segmentation loop runs inside the conda environment.
wrapperPath = os.path.join(wrapperInstallDir, 'model_wrapper','runSegmentation.py')
# Download the model-weights archive and unpack it into wrapperInstallDir.
modelWeightZipPath = os.path.join(wrapperInstallDir,'model.gz')
# NOTE(review): modelWeightPath is not referenced anywhere else in this
# notebook; presumably the archive unpacks into this directory. Confirm.
modelWeightPath = os.path.join(wrapperInstallDir,'model')
# NOTE(review): this is the same Box URL as `boxLink` in the conda-environment
# download cell. Confirm that it really is the model-weights archive.
!wget -O {modelWeightZipPath} -L https://mskcc.box.com/shared/static/o524xsg5s91q0frwka10fvsf82395sgg.gz
!tar xf {modelWeightZipPath} -C {wrapperInstallDir}
!rm {modelWeightZipPath}

# Functions for data pre- and post-processing

## Crop scan to lung extents

In [None]:
from cerr.dataclasses import structure as cerrStr
from cerr.contour import rasterseg as rs
from cerr.utils import identifyScan, imageProc
from cerr.utils import bbox
import numpy as np

def processInputData(scanNum, planC, lungNameList=('Lung_total', 'Lung_L', 'Lung_R')):
    """Crop a scan to the bounding box of the lung structure(s).

    Args:
        scanNum: Index in ``planC.scan`` of the scan to crop.
        planC: pyCERR plan container holding the scan and its structures.
        lungNameList: A structure name, or an iterable of names, matched
            exactly against the structure names in ``planC``. The crop uses
            the bounding box of the union of all matching structures (the
            first match per name).

    Returns:
        A tuple ``(scanCrop3M, (xCropV, yCropV, zCropV))``:
        - ``scanCrop3M`` is the scan array cropped to the lung bounding box.
        - The three vectors are the x/y/z coordinates of the cropped grid.

    Raises:
        Exception: If no name matches a structure, or the matched masks are empty.
    """
    if isinstance(lungNameList, str):
        lungNameList = [lungNameList]

    scan3M = planC.scan[scanNum].getScanArray()
    # Boolean accumulator. Masks are OR-ed so that e.g. Lung_L and Lung_R
    # together cover both lungs (AND-ing into an all-zero mask stays empty).
    mask3M = np.zeros(scan3M.shape, dtype=bool)

    strNames = [s.structureName for s in planC.structure]
    for lungName in lungNameList:
        lungInd = cerrStr.getMatchingIndex(lungName, strNames, 'exact')
        if len(lungInd) > 0:
            np.logical_or(mask3M, rs.getStrMask(lungInd[0], planC), out=mask3M)

    if not np.any(mask3M):
        raise Exception('Lung contour name did not match any structures in planC')

    rmin, rmax, cmin, cmax, smin, smax, _ = bbox.compute_boundingbox(mask3M)
    # NOTE: assumes compute_boundingbox returns inclusive max indices, so +1 is
    # needed to keep the last lung voxel. If it were exclusive, the extra index
    # is harmless because numpy clamps slices at the array edge.
    rowSl = slice(rmin, rmax + 1)
    colSl = slice(cmin, cmax + 1)
    slcSl = slice(smin, smax + 1)

    # Use the requested scan's grid (not scan 0) for both coordinates and voxels.
    x, y, z = planC.scan[scanNum].getScanXYZVals()
    xCropV = x[colSl]
    yCropV = y[rowSl]
    zCropV = z[slcSl]
    scanCrop3M = scan3M[rowSl, colSl, slcSl]

    return scanCrop3M, (xCropV, yCropV, zCropV)
    

## Import and refine AI segmentations

In [None]:
#Import label map to CERR
import glob
from cerr import plan_container as pc

# Label value -> structure name, one dictionary per label map written by the model.
atriaLabelDict = {1: 'DL_Atria'}
# The sub-structure label map uses label values 2..9 (in this order).
heartSubSegDict = dict(zip(range(2, 10),
                           ['AORTA', 'DL_LA', 'DL_LV', 'DL_RA',
                            'DL_RV', 'DL_IVC', 'DL_SVC', 'DL_PA']))
heartSegDict = {1: 'DL_heart'}
periLabelDict = {1: 'DL_Pericardium'}
ventriLabelDict = {1: 'DL_Ventricles'}

def postProcAndImportSeg(outputDir,procScanNum,scanNum,planC):
    """Import the model's NIfTI label maps and map them onto the original scan.

    Each ``*.nii.gz`` file in ``outputDir`` whose name matches a known model
    output is processed in three steps:

    1. Each label is imported as a structure on the cropped scan
       ``procScanNum``, using that file's label->name dictionary.
    2. Each imported structure is copied onto the original scan ``scanNum``.
    3. Only the largest connected component of the copy is kept. The result is
       imported on ``scanNum`` under the dictionary's structure name, and the
       unprocessed copy is deleted.

    Files with unrecognized names are skipped with a message.

    Args:
        outputDir: Directory containing the model's ``*.nii.gz`` outputs.
        procScanNum: Index in ``planC.scan`` of the cropped scan fed to the model.
        scanNum: Index in ``planC.scan`` of the original scan.
        planC: pyCERR plan container.

    Returns:
        The updated planC. The post-processed structures (on ``scanNum``) are
        appended; the intermediate structures imported on ``procScanNum`` are
        left in place.
    """
    # Output file name suffix -> {label value: structure name}.
    fileToLabelMap = {
        'heart.nii.gz': heartSubSegDict,
        'heartStructure.nii.gz': heartSegDict,
        'atria.nii.gz': atriaLabelDict,
        'pericardium.nii.gz': periLabelDict,
        'ventricles.nii.gz': ventriLabelDict,
    }
    numComponents = 1  # keep only the largest connected component

    for segFile in sorted(glob.glob(os.path.join(outputDir, '*.nii.gz'))):
        # Match on the file name only, so that directory names cannot trigger a match.
        fileName = os.path.basename(segFile)
        strToLabelMap = next((labelMap for suffix, labelMap in fileToLabelMap.items()
                              if fileName.endswith(suffix)), None)
        if strToLabelMap is None:
            # Without this guard, an undefined name or the previous file's
            # dictionary would be used.
            print('Skipping unrecognized segmentation file ' + segFile)
            continue

        print('Importing ' + segFile + '...')
        numStrOrig = len(planC.structure)
        planC = pc.load_nii_structure(segFile, procScanNum, planC,
                                      labels_dict=strToLabelMap)
        importedStrNumV = range(numStrOrig, len(planC.structure))
        # NOTE(review): assumes load_nii_structure appends one structure per
        # labels_dict entry, in the dict's order. Names are paired accordingly.
        if len(importedStrNumV) != len(strToLabelMap):
            print(f'Warning: {fileName}: expected {len(strToLabelMap)} '
                  f'structures, imported {len(importedStrNumV)}')

        # Pair structures with the dictionary's names rather than looking up
        # key label+1: heartSubSegDict starts at label 2, so that lookup
        # raised a KeyError and mislabeled the remaining structures.
        for impStrNum, strName in zip(importedStrNumV, strToLabelMap.values()):
            # Copy to the original (uncropped) scan.
            planC = cerrStr.copyToScan(impStrNum, scanNum, planC)
            copyStrNum = len(planC.structure) - 1
            mask3M = rs.getStrMask(copyStrNum, planC)
            procMask3M = imageProc.getLargestConnComps(mask3M, numComponents)
            planC = pc.import_structure_mask(procMask3M, scanNum, strName, [], planC)
            # The processed mask was appended after the copy, so deleting the
            # copy does not shift the indices in importedStrNumV.
            del planC.structure[copyStrNum]

    return planC

# Segment OARs

## Define I/O paths

Specify paths to the DICOM input data, desired output directory, and temporary (session) directory used to store intermediate results.

In [None]:
import os

# Paths to the input data, the exported results and the temporary session data.
# inputDicomPath should contain one sub-directory of DICOM files per patient.
inputDicomPath = '/content/sampleData/'  # Replace with path to dataset
outputDicomPath = '/content/AIoutput/'
sessionPath = '/content/temp'

# makedirs(exist_ok=True) is idempotent on re-run and also creates any missing
# parent directories (os.mkdir fails when a parent is missing). The session
# root is created as well, since per-patient session folders are made inside it.
os.makedirs(outputDicomPath, exist_ok=True)
os.makedirs(sessionPath, exist_ok=True)


## Run AI model

In [None]:
%%capture
import shlex
import subprocess
import numpy as np
import cerr
from cerr import plan_container as pc
from cerr.ai import createSessionDir as cdir
from cerr.dcm_export import rtstruct_iod

# Each entry of inputDicomPath is expected to be one patient's DICOM directory.
# Stray files (e.g. .DS_Store) are skipped, and sorting gives a reproducible order.
patientDirList = sorted(d for d in os.listdir(inputDicomPath)
                        if os.path.isdir(os.path.join(inputDicomPath, d)))
modality = 'CT SCAN'

for patientId in patientDirList:

    dcmDir = os.path.join(inputDicomPath, patientId)

    # Create a per-patient session dir to store temporary data
    modInputPath, modOutputPath = cdir.createSessionDir(sessionPath, dcmDir)

    # Import DICOM to planC
    planC = pc.load_dcm_dir(dcmDir)

    # Identify the CT scan index in planC
    scanIdS = {"imageType": modality}
    matchScanV = identifyScan.getScanNumForIdentifier(scanIdS, planC, False)
    if len(matchScanV) == 0:
        raise ValueError(f"No scan with imageType '{modality}' found in {dcmDir}")
    scanNum = matchScanV[0]

    # Pre-process: crop to the lung extents and add the cropped scan to planC
    procScan3M, resizeGridS = processInputData(scanNum, planC)
    planC = pc.import_scan_array(procScan3M, resizeGridS[0], resizeGridS[1],
                                 resizeGridS[2], modality, scanNum, planC)
    procScanNum = len(planC.scan) - 1

    # Export the model input to NIfTI
    scanFilename = os.path.join(modInputPath, f"{patientId}_scan_3D.nii.gz")
    planC.scan[procScanNum].save_nii(scanFilename)

    # Apply the model inside its conda environment. check=True raises an error
    # if the model fails, instead of silently importing missing or stale output.
    cmd = (f"source {shlex.quote(heartEnvActivateScript)} && "
           f"python {shlex.quote(wrapperPath)} "
           f"{shlex.quote(modInputPath)} {shlex.quote(modOutputPath)}")
    subprocess.run(cmd, capture_output=False, shell=True,
                   executable="/bin/bash", check=True)

    # Import results to planC
    numStrBefore = len(planC.structure)
    planC = postProcAndImportSeg(modOutputPath, procScanNum, scanNum, planC)

    # Export only the new structures associated with the original scan.
    # Intermediate structures on the cropped scan are excluded, and all label
    # maps (not just the last file's) are included.
    scanUID = planC.scan[scanNum].scanUID
    structNumV = [strNum for strNum in range(numStrBefore, len(planC.structure))
                  if planC.structure[strNum].assocScanUID == scanUID]
    structFileName = patientId + '_AI_seg.dcm'
    structFilePath = os.path.join(outputDicomPath, structFileName)
    seriesDescription = "AI Generated"
    exportOpts = {'seriesDescription': seriesDescription}
    rtstruct_iod.create(structNumV, structFilePath, planC, exportOpts)


# Display results

## Display using matplotlib

In [None]:
from cerr.viewer import showMplNb

# Show the most recently processed patient with the exported AI structures
# overlaid. scanNum, structNumV and planC hold the values from the final
# iteration of the segmentation loop.
viewWindowS = {'windowCenter': -400, 'windowWidth': 2000}
showMplNb(scanNum, structNumV, planC, **viewWindowS)