In [None]:
import time
from datetime import datetime
from pathlib import Path

import slicer
import vtk
from DICOMLib import DICOMUtils
from SimpleFilters import SimpleFiltersLogic
from SimpleITK.SimpleITK import BinaryFillholeImageFilter

In [None]:
# Input root (one sub-folder of DICOM files per patient) and output root
# (one sub-folder per patient holding the co-registered MRIs and label maps).
dicom_folder, output_folder = (
    Path(r"C:\Met Recurrence\DatasetMRIsFixed\DCMs"),
    Path(r"C:\Met Recurrence\DatasetMRIsFixed\MRIandContoursCoregistered"),
)

In [None]:
def removeHoles(labelmap_node):
    """Fill interior holes of the foreground (value 1) of ``labelmap_node`` in place.

    Runs SimpleITK's ``BinaryFillholeImageFilter`` (not fully connected) through
    Slicer's SimpleFilters module logic. The node is both the filter input and the
    output. Blocks until the logic has finished, so the caller can save the node
    immediately afterwards.

    Parameters
    ----------
    labelmap_node : vtkMRMLLabelMapVolumeNode
        Label map whose foreground holes are filled; modified in place.
    """
    fill_filter = BinaryFillholeImageFilter()
    fill_filter.SetDebug(False)
    fill_filter.SetFullyConnected(False)
    fill_filter.SetNumberOfThreads(12)
    fill_filter.SetNumberOfWorkUnits(0)
    fill_filter.SetForegroundValue(1)

    logic = SimpleFiltersLogic()
    # Build the SimpleFilters module widget before running the logic.
    # NOTE(review): presumably the logic reports back to that widget; confirm it is required.
    slicer.modules.simplefilters.widgetRepresentation()
    logic.run(fill_filter, labelmap_node, True, labelmap_node)

    # SimpleFiltersLogic runs the filter on a worker thread, but writes the result
    # into the output node from a main-thread queue that is only serviced while Qt
    # events are processed. Pump events until the worker has finished AND the queue
    # has stopped. A plain busy-wait (without processEvents) can return before the
    # filled image is applied, and it floods the cell output.
    while logic.thread.is_alive() or logic.main_queue_running:
        slicer.app.processEvents()
        time.sleep(0.01)
    print("Filter done")

def cloneNode(node):
    """Return a copy of ``node`` created through the subject hierarchy.

    Looks up ``node``'s subject-hierarchy item and clones that item with the
    subject-hierarchy module logic. Returns the data node attached to the clone.
    """
    hierarchy = slicer.vtkMRMLSubjectHierarchyNode.GetSubjectHierarchyNode(slicer.mrmlScene)
    source_item = hierarchy.GetItemByDataNode(node)
    sh_logic = slicer.modules.subjecthierarchy.logic()
    return hierarchy.GetItemDataNode(sh_logic.CloneSubjectHierarchyItem(hierarchy, source_item))

def loadDicomsInDatabase(patient_path):
    """Import one patient's DICOM folder into Slicer's main DICOM database and load it.

    Clears the main database, imports ``patient_path`` into it, then loads every
    patient in the DICOM browser's patients table via the browser's Load button.

    Parameters
    ----------
    patient_path : pathlib.Path or str
        Folder containing the patient's DICOM files (converted with ``str``).
    """
    DICOMUtils.clearDatabase()
    # The main database is used on purpose: the TemporaryDICOMDatabase might crash Slicer.
    main_database = slicer.dicomDatabase
    DICOMUtils.importDicom(str(patient_path), main_database)

    # Building the DICOM module's widget is what creates DICOMWidget.browserWidget.
    slicer.modules.dicom.widgetRepresentation()
    browser = slicer.modules.DICOMWidget.browserWidget
    patients_table = browser.dicomBrowser.dicomTableManager().patientsTable()
    patients_table.selectAll()
    # Kludge: load through the Load button because DICOMUtils.loadPatientByUID gives an error.
    browser.loadButton.clicked()

def clearScene():
    """Reset Slicer to an empty, unsaved scene before the next patient is loaded.

    Clears the MRML scene (``Clear(False)``), resets the scene URL to an empty
    string and sets the core IO manager's default scene file type back to
    "MRML Scene (.mrml)".
    """
    scene = slicer.mrmlScene
    scene.Clear(False)
    scene.SetURL("")
    io_manager = slicer.app.coreIOManager()
    io_manager.setDefaultSceneFileType("MRML Scene (.mrml)")

In [None]:
# Batch-process the patient folders under `dicom_folder`:
#   1. load the patient's DICOMs (MRIs + segmentation) into a fresh scene,
#   2. rigidly register the most recent post-contrast MRI to each older one (BRAINSFit),
#   3. carry the segmentation into each older MRI's space through that transform,
#   4. save each older MRI plus one label map per segment to `output_folder/<patient>`.
# Patients whose output folder already exists are skipped, so re-running this cell
# resumes where the previous run stopped.
# NOTE(review): if something fails after the output folder is created, the partial
# folder will make the next run skip this patient; delete it manually to retry.
MAX_PATIENTS_PER_RUN = 1  # patients to process per run; set to None to process all remaining ones

processed_count = 0
for patient_path in sorted(dicom_folder.iterdir()):
    if not patient_path.is_dir():
        continue
    patient_output_folder = output_folder / patient_path.name
    if patient_output_folder.exists():
        continue

    clearScene()
    loadDicomsInDatabase(patient_path)

    # Find segmentation node
    segmentation_nodes = slicer.util.getNodesByClass('vtkMRMLSegmentationNode')
    if not segmentation_nodes:
        print(f"{patient_path.name}: no segmentation found, skipping")
        continue
    segmentation_node = segmentation_nodes[0]
    segmentation_node.SetSourceRepresentationToClosedSurface()  # this prevents glitchy holes when hardening the transforms later on

    # Collect post-contrast MR volumes together with their acquisition dates.
    dated_mri_volumes = []
    for node in slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode'):
        if "Mr2" in node.GetName():
            continue  # we don't deal with pre contrast mris for now
        instance_uids = (node.GetAttribute("DICOM.instanceUIDs") or "").split()
        if not instance_uids:
            continue  # not loaded from DICOM, so there are no tags to read
        dicom_file = slicer.dicomDatabase.fileForInstance(instance_uids[0])
        modality = slicer.dicomDatabase.fileValue(dicom_file, "0008,0060")  # Modality
        if modality != "MR":
            continue
        acq_date = slicer.dicomDatabase.fileValue(dicom_file, "0008,0022")  # Acquisition Date
        dated_mri_volumes.append((datetime.strptime(acq_date, "%Y%m%d"), node))

    # Sort on the date alone. A plain tuple sort falls through to comparing the
    # volume nodes when two series share an acquisition date, and those nodes are
    # not orderable (TypeError).
    dated_mri_volumes.sort(key=lambda date_and_node: date_and_node[0])
    mri_volumes_sorted = [node for _, node in dated_mri_volumes]

    if len(mri_volumes_sorted) < 2:
        # Nothing to register. Skip without creating the output folder so the
        # patient is not silently recorded as done.
        print(f"{patient_path.name}: found {len(mri_volumes_sorted)} post-contrast MRI(s), "
              f"need at least 2; skipping")
        continue

    latest_volume = mri_volumes_sorted[-1]
    older_volumes = mri_volumes_sorted[:-1]
    if len(older_volumes) > 2:
        # Every older volume after the first is named "3month", so each one
        # overwrites the previous one's files.
        print(f"{patient_path.name}: {len(older_volumes)} older MRIs found; only the oldest "
              f"('6month') and the most recent older one ('3month') will remain on disk")

    print("Registering contours to older MRIs")
    patient_output_folder.mkdir(parents=True)
    slicer.app.processEvents()  # have to do this or screen doesn't update until script is done running
    for index, volume in enumerate(older_volumes):
        # Assumes the oldest scan is the 6-month follow-up and the others the 3-month one.
        timepoint_name = "6month" if index == 0 else "3month"
        slicer.util.saveNode(volume, str(patient_output_folder / f"{timepoint_name}_contrastMRI.nrrd"))

        # Rigidly register the latest MRI (moving) onto this older MRI (fixed).
        transform_node = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLLinearTransformNode')
        parameters = {
            'fixedVolume': volume,
            'movingVolume': latest_volume,
            'outputTransform': transform_node,
            'useRigid': True,
            'initializeTransformMode': 'useMomentsAlign'
        }
        slicer.cli.runSync(slicer.modules.brainsfit, parameters=parameters)
        latest_volume.SetAndObserveTransformNodeID('')  # we don't actually want the latest volume to be transformed

        # Clone the segmentation node and transform the clone
        segmentation_node_clone = cloneNode(segmentation_node)
        segmentation_node_clone.SetName(timepoint_name)
        segmentation_node_clone.SetAndObserveTransformNodeID(transform_node.GetID())
        segmentation_node_clone.HardenTransform()

        # Use the older MRI's geometry as the reference for the binary labelmap representation.
        segmentation_node_clone.SetReferenceImageGeometryParameterFromVolumeNode(volume)
        segmentation_node_clone.CreateBinaryLabelmapRepresentation()
        segmentation_node_clone.SetSourceRepresentationToBinaryLabelmap()

        # Export each segment to a labelmap volume and save
        segmentation = segmentation_node_clone.GetSegmentation()
        segment_ids = [segmentation.GetNthSegmentID(n) for n in range(segmentation.GetNumberOfSegments())]

        for segment_id in segment_ids:
            print(f"Converting {segment_id} to nrrd")

            # Create output labelmap volume
            labelmap_name = f"{timepoint_name}_{segment_id}_label"
            labelmap_node = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLLabelMapVolumeNode')
            labelmap_node.SetName(labelmap_name)

            single_segment_array = vtk.vtkStringArray()
            single_segment_array.InsertNextValue(segment_id)
            slicer.vtkSlicerSegmentationsModuleLogic.ExportSegmentsToLabelmapNode(segmentation_node_clone, single_segment_array, labelmap_node, volume)

            if segment_id.lower() == "brain":
                removeHoles(labelmap_node)

            slicer.util.saveNode(labelmap_node, str(patient_output_folder / f"{labelmap_name}.nrrd"))

    processed_count += 1
    if MAX_PATIENTS_PER_RUN is not None and processed_count >= MAX_PATIENTS_PER_RUN:
        break