In [1]:
# Notebook to run tests for the V-Net (leaky_vnet) segmentation.
# Assumes the first cell of testme_cnn.ipynb (the one that git-pulls leaky_vnet) has already been run.

# imports

import os
import sys
sys.path.insert(0, '../00_support_functions/') # add support function directory to path
from pyfunctions import *
from leaky_vnet.model import leaky_vnet as vnet

import torch

def seg_cnn(input, output_mask, output_mask_l, output_mask_r, output_box_l, output_box_r,
            thresh=0.65, logfile="", model=None):
    """Segment an MRI volume with the pretrained leaky V-Net and save the resulting masks.

    Steps:
      1. Load left and right halves of ``input`` via ``load_nii``.
      2. Mirror the right half along axis 2 and concatenate both halves into one network input.
      3. Run the network.
      4. Min-max normalise the output to [0, 1] and threshold it at ``thresh`` to get a binary mask.
      5. Save the mask to ``output_mask`` via ``npy_resize_nii``.
      6. Pass ``resize(mask, (112, 512, 512))`` to ``npy_bbox_nii`` once per side.
         It receives ``output_box_l`` / ``output_box_r`` as output paths.
      7. Save the per-side masks it returns to ``output_mask_l`` / ``output_mask_r``.

    Parameters
    ----------
    input : str
        Path to the input NIfTI volume.
    output_mask, output_mask_l, output_mask_r, output_box_l, output_box_r : str
        Destination paths. Missing parent directories are created.
    thresh : float, default 0.65
        Threshold applied to the normalised network output.
    logfile : str, default ""
        If non-empty, this file is created (or truncated). Nothing is currently written to it.
    model : torch.nn.Module, optional
        An already-loaded network to reuse across calls, which avoids re-loading the
        weights per volume. If None, the weights "doggywastaken/bmri-prep_cnn_seg" are
        loaded on this call.

    Returns
    -------
    None
    """
    # Network input grid, and the grid the mask is resized to before bounding-box extraction.
    # NOTE(review): 112x512x512 is presumably the native grid of the evaluation data -- confirm.
    net_shape = (64, 256, 256)
    bbox_shape = (112, 512, 512)

    # prep logs: the previous version built unused subprocess handles and leaked the open
    # log file. Keep the visible side effect (create/truncate the file) but close it.
    if logfile:
        with open(logfile, 'w'):
            pass

    # Create output directories up front so saving does not fail after a long inference.
    for out_path in (output_mask, output_mask_l, output_mask_r, output_box_l, output_box_r):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # load vnet (skipped when the caller supplies an already-loaded model)
    if model is None:
        print("loading network ...")
        bmriprep_cnn_seg = vnet.VNet()
        model = bmriprep_cnn_seg.from_pretrained("doggywastaken/bmri-prep_cnn_seg")
        print("..done!")
    # Switch dropout/batch-norm layers (if any) to inference behaviour.
    # NOTE(review): likely a no-op if from_pretrained already put the model in eval mode.
    model.eval()

    print("load input dataset ...")
    # Presumably each call returns one half of the volume, so that the concatenation has
    # exactly prod(net_shape) elements, as the .view() below requires. Confirm in pyfunctions.load_nii.
    ref_left = load_nii(input, "left", net_shape)
    ref_right = load_nii(input, "right", net_shape)
    # NOTE(review): the right half is mirrored here, but the prediction is saved without
    # being mirrored back. Confirm that npy_resize_nii / npy_bbox_nii account for this.
    ref = np.concatenate([ref_left, np.flip(ref_right, 2)], axis=2)
    print("..done!")

    # perform inference without building an autograd graph (saves memory and time)
    print("performing inference ...")
    inp = torch.from_numpy(ref).view(1, 1, *net_shape).float()
    with torch.no_grad():
        pred = model(inp).detach().numpy().squeeze()
    print("... done!")

    print("post-processing ...")
    pred_min = np.min(pred)
    pred_span = np.max(pred) - pred_min
    if pred_span > 0:
        pred = (pred - pred_min) / pred_span
        mask = np.where(pred > thresh, 1, 0)
    else:
        # Constant output: min-max scaling would be 0/0 (NaN). The previous code thresholded
        # those NaNs to an all-zero mask; produce that result explicitly.
        print("warning: network output is constant; mask will be empty")
        mask = np.zeros(pred.shape, dtype=int)
    print("... done!")

    # saving stuff
    print("saving mask ...")
    npy_resize_nii(mask, input, output_mask)

    print("calculating bounding-box ...")
    rs_mask = resize(mask, bbox_shape)
    msk_l = npy_bbox_nii(rs_mask, input, output_box_l, 'left')
    msk_r = npy_bbox_nii(rs_mask, input, output_box_r, 'right')

    print("saving bounding box masks ...")
    npy_resize_nii(msk_l, input, output_mask_l)
    npy_resize_nii(msk_r, input, output_mask_r)
    print("...done!")

    print("all done!")

In [2]:
# Evaluation subjects sub-001 and sub-011 (session 01).
# Segment each reference volume and write the whole/left/right masks and the left/right
# bounding-box outputs to evalresults_cnn/.
# After the loop, the test_* names stay bound to the last subject, as in the previous
# copy-pasted version.
for subject in ["sub-001", "sub-011"]:
    stem = f"{subject}_ses-01_ref"
    test_input =             f"../evaldata/{subject}/ses-01/{stem}.nii"
    test_output =            f"evalresults_cnn/{stem}_Mask.nii"
    test_output_left =       f"evalresults_cnn/{stem}_Maskl.nii"
    test_output_right =      f"evalresults_cnn/{stem}_Maskr.nii"
    test_output_bbox_left =  f"evalresults_cnn/{stem}_Segl.nii"
    test_output_bbox_right = f"evalresults_cnn/{stem}_Segr.nii"

    seg_cnn(test_input,
            test_output,
            test_output_left,
            test_output_right,
            test_output_bbox_left,
            test_output_bbox_right)

loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!
loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!


In [3]:
# Evaluation subjects sub-021, sub-022 and sub-030 (session 01).
# Same per-volume segmentation as above, one seg_cnn call per subject.
# After the loop, the test_* names stay bound to the last subject, as in the previous
# copy-pasted version.
for subject in ["sub-021", "sub-022", "sub-030"]:
    stem = f"{subject}_ses-01_ref"
    test_input =             f"../evaldata/{subject}/ses-01/{stem}.nii"
    test_output =            f"evalresults_cnn/{stem}_Mask.nii"
    test_output_left =       f"evalresults_cnn/{stem}_Maskl.nii"
    test_output_right =      f"evalresults_cnn/{stem}_Maskr.nii"
    test_output_bbox_left =  f"evalresults_cnn/{stem}_Segl.nii"
    test_output_bbox_right = f"evalresults_cnn/{stem}_Segr.nii"

    seg_cnn(test_input,
            test_output,
            test_output_left,
            test_output_right,
            test_output_bbox_left,
            test_output_bbox_right)

loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!
loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!
loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!


In [4]:
# Test datasets ds02 and ds03.
# Input lives at ../testdata/<ds>/<ds>_ref.nii; outputs go to evalresults_cnn/.
# After the loop, the test_* names stay bound to the last dataset, as in the previous
# copy-pasted version.
for dataset in ["ds02", "ds03"]:
    stem = f"{dataset}_ref"
    test_input =             f"../testdata/{dataset}/{stem}.nii"
    test_output =            f"evalresults_cnn/{stem}_Mask.nii"
    test_output_left =       f"evalresults_cnn/{stem}_Maskl.nii"
    test_output_right =      f"evalresults_cnn/{stem}_Maskr.nii"
    test_output_bbox_left =  f"evalresults_cnn/{stem}_Segl.nii"
    test_output_bbox_right = f"evalresults_cnn/{stem}_Segr.nii"

    seg_cnn(test_input,
            test_output,
            test_output_left,
            test_output_right,
            test_output_bbox_left,
            test_output_bbox_right)

loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!
loading network ...
..done!
load input dataset ...
..done!
performing inference ...
... done!
post-processing ...
... done!
saving mask ...
calculating bounding-box ...
saving bounding box masks ...
...done!
all done!
