<a href="https://colab.research.google.com/github/cerr/pycerr-notebooks/blob/main/test_ibsi2_compatibility.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction
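This notebook demonstrates image registration with ANTsPy, the Python interface to Advanced Normalization Tools (ANTs). The resulting transform is also used to warp structures and dose from the moving scan onto the fixed scan. Registration quality is evaluated with the Dice similarity coefficient (DSC) and the 95th-percentile Hausdorff distance (HD95).  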

## Define registration function ***registerScansANTS***

In [None]:
import os

import ants
import numpy as np
from cerr import plan_container as pc
from cerr.contour.rasterseg import getStrMask
from surface_distance import compute_surface_distances, compute_dice_coefficient, compute_robust_hausdorff

def _warpToFile(fixedImg, moving, transformlist, outFile, regenerate, **applyKwargs):
    """Resample `moving` into `fixedImg`'s space and write the result to `outFile`.

    The warp is skipped when `regenerate` is False and `outFile` already exists,
    so results cached by an earlier run are reused. `moving` may be an ANTsImage
    or a path to an image file (read lazily, only if the warp is needed).
    Extra keyword arguments are forwarded to `ants.apply_transforms`.
    """
    if not regenerate and os.path.isfile(outFile):
        return
    if isinstance(moving, str):
        moving = ants.image_read(moving)
    warpedImg = ants.apply_transforms(fixed=fixedImg, moving=moving,
                                      transformlist=transformlist, **applyKwargs)
    warpedImg.image_write(outFile)


def registerScansANTS(planC, transformType, baseScanFile, movScanFile, movDoseFileList,
                      baseStrNumV, movStructNumV, roiDict, tmpDir):
    """Register a moving scan to a base scan with ANTsPy and warp its structures and doses.

    The moving scan is registered to the base (fixed) scan. The resulting transform
    is then applied to the moving scan, to a label map of the selected moving
    structures, and to each moving dose. All warped objects are resampled onto the
    base scan's grid and imported into ``planC``. Dice and 95th-percentile Hausdorff
    distances are computed between each base structure and its warped counterpart.

    Caching: if ``<tmpDir>/ants_regComposite.h5`` exists, the registration is not
    re-run and previously written warped files are reused. After a fresh
    registration every warped file is regenerated, so stale results are never mixed in.

    Args:
        planC: pyCERR plan container holding the base and moving objects.
        transformType: ANTs transform type passed to ``ants.registration``
            as ``type_of_transform`` (e.g. ``'SyN'``).
        baseScanFile: NIfTI file of the fixed (base) scan.
        movScanFile: NIfTI file of the moving scan.
        movDoseFileList: a NIfTI dose path, or a list of paths, to warp.
            Relative paths are resolved against ``tmpDir``.
        baseStrNumV: indices of base structures; paired in order with the warped
            structures for the metrics.
        movStructNumV: indices of moving structures to warp.
        roiDict: mapping of structure name -> integer label used in the label map.
        tmpDir: directory for intermediate and cached files.

    Returns:
        tuple: ``(warpedScanNum, warpedStrNumV, warpedDoseNumV, diceV, hd95V, planC)``.

    Raises:
        ValueError: if the number of warped structures differs from ``len(baseStrNumV)``.
    """
    # Accept a single dose file as well as a list; iterating a str would walk its characters.
    if isinstance(movDoseFileList, str):
        movDoseFileList = [movDoseFileList]

    # Read fixed and moving images
    imgAntsBase = ants.image_read(baseScanFile)
    imgAntsMov = ants.image_read(movScanFile)
    movScanName = os.path.basename(movScanFile)

    # Register images, or reuse the composite transform cached by a previous run.
    # With write_composite_transform=True, ANTs writes '<outprefix>Composite.h5'.
    txLoc = os.path.join(tmpDir, 'ants_reg')
    regFile = txLoc + 'Composite.h5'
    freshRegistration = not os.path.isfile(regFile)
    if freshRegistration:
        result = ants.registration(imgAntsBase, imgAntsMov, type_of_transform=transformType,
                                   write_composite_transform=True, outprefix=txLoc)
        transformList = result['fwdtransforms']
    else:
        transformList = regFile

    # Warp the moving scan and import it into planC
    warpedScanFileName = os.path.join(tmpDir, 'warped_' + movScanName)
    _warpToFile(imgAntsBase, imgAntsMov, transformList, warpedScanFileName, freshRegistration)
    planC = pc.loadNiiScan(warpedScanFileName, 'CT SCAN', '', planC)
    warpedScanNum = len(planC.scan) - 1

    # Warp associated ROIs, exported together as one label map.
    # NOTE(review): mov_structs.nii.gz is rewritten on every call, but a cached warped
    # label map is reused when the registration is cached; delete it if movStructNumV changes.
    movStrFile = os.path.join(tmpDir, 'mov_structs.nii.gz')
    pc.saveNiiStructure(movStrFile, roiDict, planC, strNumV=movStructNumV, dim=3)
    movScanStem = movScanName.split('.')[0]
    warpedStrFileName = os.path.join(tmpDir, 'warped_' + movScanStem + '_structs.nii.gz')
    origNumStrs = len(planC.structure)
    # Label images must not be linearly interpolated: blending labels 0 and 2 would
    # create spurious voxels of label 1. 'genericLabel' keeps labels integral.
    _warpToFile(imgAntsBase, movStrFile, transformList, warpedStrFileName, freshRegistration,
                interpolator='genericLabel', defaultvalue=0, singleprecision=True)
    flipDict = {label: name for name, label in roiDict.items()}
    planC = pc.loadNiiStructure(warpedStrFileName, warpedScanNum, planC, flipDict)
    warpedStrNumV = np.arange(origNumStrs, len(planC.structure))

    # Warp doses
    warpedDoseNumV = []
    for movDoseFile in movDoseFileList:
        doseFileName = os.path.basename(movDoseFile)
        warpedDoseFileName = os.path.join(tmpDir, 'warped_' + doseFileName)
        _warpToFile(imgAntsBase, os.path.join(tmpDir, movDoseFile), transformList,
                    warpedDoseFileName, freshRegistration)
        warpedDoseName = 'warped_' + doseFileName.split('.')[0]
        # The warped dose is resampled onto the fixed grid, so it belongs with the
        # warped scan rather than the original moving scan.
        planC = pc.loadNiiDose(warpedDoseFileName, warpedScanNum, planC,
                               fractionGroupID=warpedDoseName)
        warpedDoseNumV.append(len(planC.dose) - 1)

    # Record registration performance metrics
    if len(baseStrNumV) != len(warpedStrNumV):
        raise ValueError(
            'Expected %d warped structures to pair with baseStrNumV, got %d'
            % (len(baseStrNumV), len(warpedStrNumV)))
    # Warped structures live on the warped scan's grid, which is the fixed image's grid.
    # The * 10 assumes pyCERR spacing is in cm (converted to mm).
    # NOTE(review): compute_surface_distances expects spacing in the mask's axis order;
    # confirm getScanSpacing() ordering matches getStrMask() axes.
    spacing_mm = planC.scan[warpedScanNum].getScanSpacing() * 10
    diceV = []
    hd95V = []
    for baseStrNum, warpedStrNum in zip(baseStrNumV, warpedStrNumV):
        M1 = getStrMask(baseStrNum, planC)
        M2 = getStrMask(warpedStrNum, planC)
        surf_dists = compute_surface_distances(M1, M2, spacing_mm)
        diceV.append(compute_dice_coefficient(M1, M2))
        hd95V.append(compute_robust_hausdorff(surf_dists, 95))

    return warpedScanNum, warpedStrNumV, warpedDoseNumV, diceV, hd95V, planC

## Apply ***registerScansANTS***
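- **Inputs:** a DICOM directory containing the fixed and moving scans, each with its associated RTSTRUCT and RTDOSE.
- **Outputs:** the pyCERR `planC` with the warped scan, structures, and doses added, plus registration quality metrics (DSC, HD95) for each structure pair.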

In [None]:
dcmDir = '/path/to/input/dataset'
tmpDir = '/path/for/intermediate/files'
# NIfTI exports below write into tmpDir, so make sure it exists.
os.makedirs(tmpDir, exist_ok=True)

# Registration options (a plain string; a trailing comma would make it a tuple)
transformType = 'SyN'

# Load DICOM data
planC = pc.loadDcmDir(dcmDir)
# Structure name -> integer label used when exporting structures as a NIfTI label map
roiDict = {'Left_masseter': 1, 'Right_masseter': 2, 'Left_medial_pterygoid': 3, 'Right_medial_pterygoid': 4}

# Indices of the objects to use (as ordered in planC)
baseScanNum = 0
movScanNum = 1
baseStrNumV = [1, 2]    # base structures, paired in order with the warped movStructNumV
movStructNumV = [4, 5]
baseDoseNum = 0
movDoseNum = 1

# Export relevant files to NIfTI
baseScanFile = os.path.join(tmpDir, 'base_scan.nii.gz')
movScanFile = os.path.join(tmpDir, 'mov_scan.nii.gz')
baseDoseFile = os.path.join(tmpDir, 'base_dose.nii.gz')
movDoseFile = os.path.join(tmpDir, 'mov_dose.nii.gz')

planC.scan[baseScanNum].saveNii(baseScanFile)
planC.scan[movScanNum].saveNii(movScanFile)
planC.dose[baseDoseNum].saveNii(baseDoseFile)
planC.dose[movDoseNum].saveNii(movDoseFile)

# Register scans and deform associated objects.
# Unpack in the order registerScansANTS returns: scan, structures, doses, DSC, HD95, planC.
warpedScanNumV, warpedStrNumV, warpedDoseNumV, diceV, hd95V, planC = registerScansANTS(
    planC,
    transformType,
    baseScanFile,
    movScanFile,
    [movDoseFile],  # the function expects a list of dose files
    baseStrNumV,
    movStructNumV,
    roiDict,
    tmpDir)