In [1]:
import numpy as np
import scipy
import scipy.ndimage
from tensorflow.keras.models import load_model

In [None]:
def Stroke_closing(img, size=2):
    """Apply binary morphological closing to a (stroke / mask) prediction image.

    Closing (dilation followed by erosion) fills gaps and holes smaller than
    the structuring element while leaving larger foreground shapes in place.

    Parameters
    ----------
    img : array_like
        Image to close; non-zero voxels are treated as foreground.
    size : int, default 2
        Edge length of the cubic all-ones structuring element. The element has
        one axis per dimension of ``img`` (the original was fixed to 3-D).

    Returns
    -------
    numpy.ndarray of bool
        The closed image, same shape as ``img``.
    """
    img = np.asarray(img)
    # Build the structuring element to match the input rank, so 2-D slices work
    # as well as the 3-D volumes this was written for.
    structure = np.ones((size,) * img.ndim)
    # Use the public scipy.ndimage function; the scipy.ndimage.morphology
    # namespace is deprecated.
    return scipy.ndimage.binary_closing(img, structure=structure)

In [None]:
def get_MaskNet_MNI(model, Dwi_MNI_img, B0_MNI_img, factor=4, threshold=0.5):
    """Infer a binary brain mask from DWI and B0 volumes in MNI space with MaskNet.

    Pipeline:
      1. Down-sample both volumes by striding ``factor`` along the first three
         axes, z-score each one independently and stack them as two channels
         (plus a leading batch axis) for ``model``.
      2. Binarize the prediction at ``threshold``.
      3. Keep only the largest connected component.
      4. Morphological closing (``Stroke_closing``), then fill holes.
      5. Up-sample back by repeating each voxel ``factor`` times per axis.

    Parameters
    ----------
    model : object
        Pre-trained MaskNet; anything with a Keras-style
        ``predict(x, verbose=0)`` method.
    Dwi_MNI_img, B0_MNI_img : numpy.ndarray
        3-D input volumes in MNI space. With the default factor, the
        down-sampled shape fed to the model is expected to be (48, 56, 48)
        per channel (TODO confirm the exact MNI grid of the inputs).
    factor : int, default 4
        Down-sampling stride and matching up-sampling repeat factor.
    threshold : float, default 0.5
        Probability threshold applied to the model output.

    Returns
    -------
    numpy.ndarray of bool
        The brain mask at the up-sampled resolution. All False when the model
        predicts no voxel above ``threshold``.
    """
    def _standardize(vol):
        # z-score each modality independently before feeding the network.
        return (vol - np.mean(vol)) / np.std(vol)

    # Stride-based down-sampling; np.newaxis appends the channel axis.
    dwi = _standardize(Dwi_MNI_img[::factor, ::factor, ::factor, np.newaxis])
    b0 = _standardize(B0_MNI_img[::factor, ::factor, ::factor, np.newaxis])
    x = np.concatenate((dwi, b0), axis=3)[np.newaxis]  # add batch axis

    # Inference, then binarization.
    y_pred = np.squeeze(model.predict(x, verbose=0)) > threshold

    labels, num_features = scipy.ndimage.label(y_pred)
    if num_features == 0:
        # Nothing predicted: return an empty mask instead of computing the
        # mode of an empty label set (which errors or yields NaN).
        mask = np.zeros(y_pred.shape, dtype=bool)
    else:
        # Largest connected component. Background (label 0) is excluded, and
        # argmax picks the smallest label on ties.
        sizes = np.bincount(labels.ravel())
        sizes[0] = 0
        mask = labels == sizes.argmax()
        mask = Stroke_closing(mask)
        mask = scipy.ndimage.binary_fill_holes(mask)

    # Nearest-neighbour up-sampling back to the input resolution.
    upsampling_mask = mask
    for axis in (0, 1, 2):
        upsampling_mask = np.repeat(upsampling_mask, factor, axis=axis)
    return upsampling_mask

In [None]:
# Load the pre-trained MaskNet. compile=False: in the visible cells the model is
# only used for inference (model.predict), so no optimizer/loss setup is needed.
# NOTE(review): MaskNet_name is not defined in the visible cells; presumably the
# model file path is set in an earlier cell. Confirm for Restart & Run All.
MaskNet = load_model(
    MaskNet_name,
    compile=False,
)

In [None]:
# Predict the brain mask in MNI space, then warp it back to the subject's raw
# (native) space. Nearest-neighbour interpolation avoids blending mask values.
mask_MNI_img = get_MaskNet_MNI(MaskNet, Dwi_MNI_img, B0_MNI_img)
mask_raw_img = affine_map.transform_inverse(
    (mask_MNI_img > 0.5).astype(int),
    interpolation='nearest',
)
# Re-binarize the warped mask as float 0.0 / 1.0.
mask_raw_img = (mask_raw_img > 0.5).astype(float)

if generate_brainmask:
    # Save as int16 NIfTI, reusing Dwi_imgJ via get_new_NibImgJ (helper
    # defined outside the visible cells; presumably copies its header/affine).
    mask_raw_ImgJ = get_new_NibImgJ(mask_raw_img, Dwi_imgJ, dataType=np.int16)
    nib.save(
        mask_raw_ImgJ,
        os.path.join(SubjDir, SubjID + '_Mask.nii.gz'),
    )