In [1]:
import os
import cv2
import time
import datetime
import numpy as np
import nibabel as nib
from glob import glob
from skimage import measure
from utils.config import cfg
from tasks.aneurysm.nets.aneurysm_net import DAResUNet
from tasks.aneurysm.datasets.aneurysm_dataset import ANEURYSM_SEG

import torch
import torch.backends.cudnn as cudnn

In [2]:
# params init
cfg_path = 'tasks/configs/aneurysm_seg.daresunet.yaml'
model_path = 'raws/weight/da_resunet.pth.tar'
cfg.merge_from_file(cfg_path)


gpus = [0,1]
use_gpu = ','.join(map(str, gpus))
os.environ["CUDA_VISIBLE_DEVICES"] = use_gpu
print('USE GPU: ', use_gpu)

WORKERS = 8
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


WL, WW = 450, 900
print('USE WL_WW:', WL, WW)

# evaluation data list 
data_root = './raws'
mask_root = '%s/mask' % data_root
image_root = '%s/image' % data_root
test_pids_file = '@./raw/part_test.txt'
test_pids_list = ['example']

# segmentation save
STORE_MASK = True
SAVE_DIR = './raws/output'

#
VIS_ROOT = './raws/vis'

cfg.TASK.STATUS = 'test'
cfg.TEST.DATA.NII_FOLDER = image_root
cfg.TRAIN.DATA.WL_WW = WL, WW

cudnn.benchmark = True
cudnn.deterministic = True

USE GPU:  0,1
USE WL_WW: 450 900


In [3]:
## model init
net = DAResUNet(segClasses=2, k=32, input_channel=1)
params = torch.load(model_path, map_location='cpu')
net.load_state_dict(params['model'])
print('current model epoch=%d' % params['epoch'])
net = net.to(device)
if len(gpus) > 1:
    net = torch.nn.DataParallel(net, device_ids=range(len(gpus)))
net.eval()

current model epoch=68


DataParallel(
  (module): DAResUNet(
    (layer0): CBR(
      (conv): Conv3d(1, 32, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
      (bn): BatchNorm3d(32, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
      (act): ReLU(inplace)
    )
    (class0): Sequential(
      (0): BasicBlock(
        (c1): CBR(
          (conv): Conv3d(96, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (bn): BatchNorm3d(64, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
          (act): ReLU(inplace)
        )
        (c2): CB(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (bn): BatchNorm3d(64, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
        )
        (act): ReLU(inplace)
        (downsample): Sequential(
          (0): Conv3d(96, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, 

In [4]:
## data init
EVAL_FILES = None
EVAL_FILES = test_pids_file
EVAL_LIST = []
EVAL_LIST = test_pids_list
if EVAL_FILES is not None and os.path.exists(EVAL_FILES):
    with open(EVAL_FILES, 'r') as f:
        lines = f.readlines()
        subjects = [line.strip() for line in lines]
elif EVAL_LIST is not None:
    subjects = EVAL_LIST
else:
    subjects = []
subjects.sort()
print("the eval data count: %d" % len(subjects))

the eval data count: 1


In [5]:
## eval
eval_results = {}
cuda_times = []
with torch.no_grad():
    for step, subject in enumerate(subjects, start=1):
        para_dict = {"subject": subject}
        eval_set = ANEURYSM_SEG(para_dict, "test")
        kwargs = {'shuffle': False, 'pin_memory': True, 
                  'drop_last': False, 'batch_size': BATCH_SIZE, 'num_workers': WORKERS}
        data_loader = torch.utils.data.DataLoader(eval_set, **kwargs)
        v_x, v_y, v_z = eval_set.volume_size()
        other_infos = eval_set.get_other_infos()  # for seg eval

        seg = torch.FloatTensor(v_x, v_y, v_z).zero_()
        seg = seg.to(device)

        time_start = time.time()
        for i, (image, coord) in enumerate(data_loader):
            image = image.to(device)
            out = net(image)
            pred = torch.nn.functional.softmax(out['y'], dim=1)
            for idx in range(image.size(0)):
                sx, ex = coord[idx][0][0], coord[idx][0][1]
                sy, ey = coord[idx][1][0], coord[idx][1][1]
                sz, ez = coord[idx][2][0], coord[idx][2][1]

                seg[sx:ex, sy:ey, sz:ez] += pred[idx][1]  # accsum

        time_end = time.time()
        seg = (seg >= 0.30).cpu().numpy().astype(np.uint8)  # binary, mask

        if STORE_MASK:
            eval_set.save(seg.copy(), SAVE_DIR)

        eval_results[subject] = np.transpose(seg, (1, 2, 0)), other_infos  # (z.x,y) => (x,y,z)
        cuda_time = time_end - time_start
        cuda_times += [cuda_time]
        print(datetime.datetime.now(),\
              '%d/%d: %s finished! cuda time=%.2f s!' % (step, len(subjects), subject, cuda_time))
        torch.cuda.empty_cache()
    print('total data: %.2s,avg cuda time: %.2s' % \
          (len(cuda_times), 1. * sum(cuda_times) / len(cuda_times)))


2020-08-27 00:21:05,590 INFO    [base_dataset.py, 196] stage:test load:1008 nums!


2020-08-27 00:21:44.358332 1/1: example finished! cuda time=37.48 s!
total data: 1,avg cuda time: 37


In [6]:
def set_window_wl_ww(tensor, wl=225, ww=450):
    w_min, w_max = wl - ww // 2, wl + ww // 2
    tensor[tensor < w_min] = w_min
    tensor[tensor > w_max] = w_max
    tensor = ((1.0 * (tensor - w_min) / (w_max - w_min)) * 255).astype(np.uint8)

    return tensor

In [7]:
## visualization
for subject in subjects:
    vis_save_dir = '%s/%s' % (VIS_ROOT, subject)
    os.makedirs(vis_save_dir, exist_ok=True)
    seg = eval_results[subject][0].astype(np.uint8)
    
    mask_path = '%s/%s_mask.nii.gz' % (mask_root, subject)
    mask = nib.load(mask_path).get_data().astype(np.uint8)
    
    image_path = '%s/%s.nii.gz' % (image_root, subject)
    image = nib.load(image_path).get_data()

    vis_indices_seg = np.where(seg.sum(axis=(0, 1)))[0].tolist()
    vis_indices_mask = np.where(mask.sum(axis=(0, 1)))[0].tolist()
    vis_all_indices = list(set(vis_indices_seg + vis_indices_mask))

    ww, wl = 800, 300

    vis_all_select_image = {}
    # vis all select slice
    for slice_index in vis_all_indices:
        vis_all_select_image[slice_index] = []
        # 
        slice_image = image[..., slice_index]
        slice_image = set_window_wl_ww(slice_image, wl=wl, ww=ww)
        rgb_slice_image = np.stack((slice_image, slice_image, slice_image), 2).copy()  # x, y, c=3
        rgb_slice_image = rgb_slice_image.transpose(1,0,2).copy()
        vis_info = 'IMG %04d' % (slice_index,)
        cv2.putText(rgb_slice_image, vis_info, (4, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.72, (255, 255, 0), 2)
        vis_all_select_image[slice_index] += [rgb_slice_image]


    # vis mask select slice
    for slice_index in vis_all_indices:
        rgb_slice_image = vis_all_select_image[slice_index][0].copy()
        slice_bin_mask = mask[..., slice_index].astype(np.uint8).copy()
        contours, hierarchy = cv2.findContours(slice_bin_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
        contours = [_[..., ::-1] for _ in contours]
        cv2.drawContours(rgb_slice_image, contours, -1, color=(0, 0, 255), thickness=1)

        vis_info = 'MASK %04d' % (slice_index,)
        cv2.rectangle(rgb_slice_image, (0, 0), (150, 32), color=(100, 150,0), thickness=cv2.FILLED)
        cv2.putText(rgb_slice_image, vis_info, (4, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.72, (255, 255, 0), 2)
        vis_all_select_image[slice_index] += [rgb_slice_image]

    # vis seg select slice
    for slice_index in vis_all_indices:
        rgb_slice_image = vis_all_select_image[slice_index][0].copy()
        slice_bin_seg = seg[..., slice_index].astype(np.uint8).copy()
        contours, hierarchy = cv2.findContours(slice_bin_seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
        contours = [_[..., ::-1] for _ in contours]
        cv2.drawContours(rgb_slice_image, contours, -1, color=(0, 0, 255), thickness=1)

        vis_info = 'SEG %04d' % (slice_index,)
        cv2.rectangle(rgb_slice_image, (0, 0), (150, 32), color=(200,150,0), thickness=cv2.FILLED)
        cv2.putText(rgb_slice_image, vis_info, (4, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.72, (255, 255, 0), 2)
        vis_all_select_image[slice_index] += [rgb_slice_image]

    for slice_index, rgb_slice_images in vis_all_select_image.items():
        save_path = '%s/%03d.png' % (vis_save_dir, slice_index)
        slice_image_merge = np.hstack(rgb_slice_images)
        cv2.imwrite(save_path, slice_image_merge)