# In [None]:
import slicer
import vtk
import numpy as np
import os

# Root folders: every sub-directory of input_root_dir is treated as one
# patient, and that patient's screenshots are written to a sub-directory of
# the same name under output_root_dir.
# NOTE(review): '/' looks like a placeholder; set real paths before running.
input_root_dir = '/'
output_root_dir = '/'

# Sorted so patients are processed in a deterministic order (the order
# returned by os.listdir is arbitrary).
patient_ids = sorted(
    d for d in os.listdir(input_root_dir)
    if os.path.isdir(os.path.join(input_root_dir, d))
)

def _ijkToRas(volumeNode, ijk):
    """Map a (possibly fractional) IJK index of volumeNode to RAS.

    Uses only the node's own IJK-to-RAS matrix; parent transforms are not
    applied. Returns the first three components (R, A, S).
    """
    ijkToRAS = vtk.vtkMatrix4x4()
    volumeNode.GetIJKToRASMatrix(ijkToRAS)
    ras = ijkToRAS.MultiplyPoint([float(ijk[0]), float(ijk[1]), float(ijk[2]), 1.0])
    return ras[:3]


def computeMaskCentroid(maskNode):
    """Return the RAS coordinates of the centroid of the non-zero mask voxels.

    Raises:
        ValueError: if the mask contains no voxel with a value > 0.
    """
    maskArray = slicer.util.arrayFromVolume(maskNode)
    indices = np.argwhere(maskArray > 0)
    if indices.size == 0:
        raise ValueError('There are no non-zero pixels in the mask')
    # arrayFromVolume indexes voxels as (k, j, i); reverse to (i, j, k).
    ijk_center = indices.mean(axis=0)[::-1]
    return _ijkToRas(maskNode, ijk_center)


def getVolumeCenter(volumeNode):
    """Return the RAS coordinates of the geometric center of the voxel grid.

    IJK indices address voxel centers, so along each axis the grid spans
    [-0.5, dim - 0.5] and its center is at (dim - 1) / 2.
    """
    dims = volumeNode.GetImageData().GetDimensions()
    return _ijkToRas(volumeNode, [(d - 1) / 2.0 for d in dims])


def jumpToRAS(ras_point):
    """Center every slice node in the scene on the given RAS point."""
    for sliceNode in slicer.util.getNodesByClass('vtkMRMLSliceNode'):
        sliceNode.JumpSliceByCentering(ras_point[0], ras_point[1], ras_point[2])


def hideSliceViewAnnotations(layoutManager=None):
    """Blank the corner annotations of every slice view in the layout.

    Clears the text of all four corners, sets the maximum line height and
    font size to 0, and re-renders each view.
    """
    if layoutManager is None:
        layoutManager = slicer.app.layoutManager()
    for sliceViewName in layoutManager.sliceViewNames():
        sliceView = layoutManager.sliceWidget(sliceViewName).sliceView()
        cornerAnnotation = sliceView.cornerAnnotation()
        cornerAnnotation.SetMaximumLineHeight(0)
        for i in range(4):
            cornerAnnotation.SetText(i, "")
        cornerAnnotation.GetTextProperty().SetFontSize(0)
        cornerAnnotation.Modified()
        sliceView.forceRender()


def showSliceViewAnnotations(layoutManager=None):
    """Undo the size changes made by hideSliceViewAnnotations.

    Resets the maximum line height to 1.0 and the font size to 12 (not to
    whatever the size was before hiding). The cleared corner text is not
    restored here; presumably Slicer repopulates it on its next update, so verify.
    """
    if layoutManager is None:
        layoutManager = slicer.app.layoutManager()
    for sliceViewName in layoutManager.sliceViewNames():
        sliceView = layoutManager.sliceWidget(sliceViewName).sliceView()
        cornerAnnotation = sliceView.cornerAnnotation()
        cornerAnnotation.SetMaximumLineHeight(1.0)
        cornerAnnotation.GetTextProperty().SetFontSize(12)
        cornerAnnotation.Modified()
        sliceView.forceRender()


def _findMaskFile(patient_input_dir,
                  maskFileNames=('DWI_mask_MR0.nii.gz', 'DWI_mask_MR0_thr.nii.gz')):
    """Return the path of the first mask file that exists, or None."""
    for maskFileName in maskFileNames:
        candidate = os.path.join(patient_input_dir, maskFileName)
        if os.path.exists(candidate):
            return candidate
    return None


def _saveSliceScreenshots(layoutManager, patient_id, saveDirectory,
                          sliceViewNames=('Red', 'Yellow', 'Green')):
    """Grab each named slice view and save it as <name>_slice.png.

    Views not present in the current layout are skipped with a message
    instead of crashing.
    """
    for sliceViewName in sliceViewNames:
        sliceWidget = layoutManager.sliceWidget(sliceViewName)
        if sliceWidget is None:
            print(f'slice view {sliceViewName} is not in the current layout, skipped')
            continue
        sliceView = sliceWidget.sliceView()
        sliceView.forceRender()
        savePath = os.path.join(saveDirectory, f'{sliceViewName}_slice.png')
        if sliceView.grab().save(savePath):
            print(f'saved patient {patient_id}  {sliceViewName} view to {savePath}')
        else:
            print(f'failed to save patient {patient_id} {sliceViewName} view to {savePath}')


def _processPatient(patient_id, input_root_dir, output_root_dir):
    """Load one patient's DWI (and optional mask), center the views and save screenshots.

    The view center is the mask centroid when a usable mask exists, otherwise
    the DWI grid center.

    Returns:
        True if screenshots were taken, False if the patient was skipped.
    """
    patient_input_dir = os.path.join(input_root_dir, patient_id)
    patient_output_dir = os.path.join(output_root_dir, patient_id)

    dwiFilePath = os.path.join(patient_input_dir, 'DWI_MR0_BET.nii.gz')
    if not os.path.exists(dwiFilePath):
        print(f'patient {patient_id}  DWI images not exist，pass')
        return False

    dwiNode = slicer.util.loadVolume(dwiFilePath)
    if not dwiNode:
        print(f'can not load patient {patient_id} DWI，pass')
        return False

    # Only create the output folder once there is something to write into it.
    os.makedirs(patient_output_dir, exist_ok=True)

    maskNode = None
    maskFilePath = _findMaskFile(patient_input_dir)
    if maskFilePath is None:
        print(f'patient {patient_id} mask not exist，will use the image center')
    else:
        maskNode = slicer.util.loadLabelVolume(maskFilePath)
        if not maskNode:
            print(f'can not load patient {patient_id} mask，will use the image center')
            maskNode = None

    # **Calculate the view center**
    ras_point = None
    if maskNode is not None:
        try:
            ras_point = computeMaskCentroid(maskNode)
        except Exception as err:  # best-effort: fall back to the image center
            print(f'patient {patient_id} center of mass cannot be calculated , so the center of the image will be used as the view center.')
            print(f'  reason: {err}')
            maskNode = None
    if ras_point is None:
        ras_point = getVolumeCenter(dwiNode)

    slicer.util.setSliceViewerLayers(background=dwiNode)
    if maskNode is not None:
        slicer.util.setSliceViewerLayers(label=maskNode)
        colorNode = slicer.util.getNode('GenericAnatomyColors')
        maskNode.GetDisplayNode().SetAndObserveColorNodeID(colorNode.GetID())

    jumpToRAS(ras_point)
    layoutManager = slicer.app.layoutManager()
    # Detach the label layer before grabbing, so the mask only decides where
    # the views are centered and is not drawn in the screenshots.
    slicer.util.setSliceViewerLayers(label=None)
    slicer.app.processEvents()

    hideSliceViewAnnotations(layoutManager)
    try:
        _saveSliceScreenshots(layoutManager, patient_id, patient_output_dir)
    finally:
        # Restore the display even if saving failed.
        if maskNode is not None:
            slicer.util.setSliceViewerLayers(label=maskNode)
        slicer.app.processEvents()
        showSliceViewAnnotations(layoutManager)
    return True


for patient_id in patient_ids:
    print(f'in processing {patient_id} ...')
    try:
        processed = _processPatient(patient_id, input_root_dir, output_root_dir)
    finally:
        # Always reset the scene so nodes loaded for this patient cannot leak
        # into the next one, even when processing raised.
        slicer.mrmlScene.Clear(0)
    if processed:
        print(f'finished patient {patient_id}\n')