In [1]:
import nibabel as nib
import numpy as np

import SimpleITK as sitk
import os
import glob

import matplotlib.pyplot as plt
import nibabel as nib

import argparse
import math
import os
from distutils.version import LooseVersion

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from config import TrainGlobalConfig
from dataset import BraTSDataset
from model import ModelBuilder
from utils import AverageMeter, PytorchTrainer

from transforms import Rot90, Flip, Identity, Compose
from transforms import GaussianBlur, Noise, Normalize, RandSelect
from transforms import RandCrop, CenterCrop, Pad,RandCrop3D,RandomRotion,RandomFlip,RandomIntensityChange
from transforms import NumpyType

from vis_utils import myshow


In [2]:
    # Select the GPU only when CUDA is really usable. `is_available` has to be
    # *called*: the bare function object is always truthy, which would pick
    # 'cuda' unconditionally and fail later on a CPU-only machine.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the network from the hyper-parameters stored on TrainGlobalConfig.
    builder = ModelBuilder()
    model = builder.build_net(
        arch=TrainGlobalConfig.id,
        num_input=TrainGlobalConfig.num_input,
        num_classes=TrainGlobalConfig.num_classes,
        num_branches=TrainGlobalConfig.num_branches,
        padding_list=TrainGlobalConfig.padding_list,
        dilation_list=TrainGlobalConfig.dilation_list,
    )

    model = model.to(device)

    # Restore the trained weights. `map_location` remaps the stored tensors onto
    # `device`, so a checkpoint containing CUDA tensors can still be loaded when
    # no GPU is present.
    # NOTE(review): absolute, user-specific path; consider moving it to the config.
    ch = torch.load(
        "/home/koga/workspace/autofucus_layer/result/ex5/last-checkpoint.bin",
        map_location=device,
    )
    model.load_state_dict(ch["model_state_dict"])

                    
    # Text files listing the cases of fold 0. Only `valid_list` is consumed in
    # this cell; `train_list` is defined but not used here.
    # NOTE(review): absolute, user-specific directory; consider a configurable path.
    split_dir = "/home/koga/dataset/BRATS2018/Train"
    train_list = os.path.join(split_dir, "train_0.txt")
    valid_list = os.path.join(split_dir, "valid_0.txt")

    # Validation split of the dataset, created without any transforms.
    valid_dataset = BraTSDataset(
        list_file=valid_list,
        root=TrainGlobalConfig.root_path,
        phase="val",
        transforms=None,
    )

    # Fixed-order loader (shuffle disabled) over the validation dataset.
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        num_workers=TrainGlobalConfig.num_workers,
        shuffle=False,
        pin_memory=True,
    )


In [3]:
# Pick one validation case and add a leading batch axis to both tensors.
index = 0

# Fetch the sample a single time: indexing the dataset twice would load the
# case twice and would only give a matching image/label pair as long as
# __getitem__ is deterministic.
sample = valid_dataset[index]
image = sample["images"].unsqueeze(0)
label = sample["labels"].unsqueeze(0)

# Show which label values occur in this case.
print(label.unique())

tensor([0, 1, 2, 3])


In [4]:
# Display the shape of the label tensor (batch axis included).
label.shape

torch.Size([1, 75, 75, 75])

In [5]:
# Run the network on the selected sample and embed its (smaller) output into a
# zero volume of the label's size.
# NOTE: this cell flattens `label` on its last line, so it cannot be re-run
# without first re-running the cell that creates `image` / `label`.
model.eval()
with torch.no_grad():
    image = image.to(device, dtype=torch.float)
    label = label.to(device, dtype=torch.long)

    # The class axis is dim 1 of the 5-D network output; argmax over it yields
    # one class index per voxel (same result as permuting the class axis to the
    # end and taking the indices of torch.max).
    out = model(image)
    out = out.argmax(dim=1)

    # Offset of the predicted region inside the label volume.
    # NOTE(review): hard-coded to 14 voxels per axis; presumably the margin the
    # network trims from each side -- confirm against the model definition.
    start = [14, 14, 14]
    center = out.size()[1:]
    region = tuple(slice(s, s + c) for s, c in zip(start, center))

    # Zero-filled (float, CPU) volume with the label's shape; everything outside
    # `region` therefore stays at value 0.
    prediction = torch.zeros(label.size())
    prediction[(slice(None),) + region] = out

    # Keep a reference to the un-flattened volume for later cells.
    p = prediction
    print(p.unique())

    # Flatten both tensors. Use the selected `device` rather than a hard-coded
    # .cuda(), which raises on CPU-only machines and ignores `device`.
    prediction = prediction.contiguous().view(-1).to(device)
    label = label.contiguous().view(-1)

tensor([0., 2.])


In [7]:
# Drop the batch axis, convert the prediction to a NumPy array and hand it to
# the viewer as a SimpleITK image.
pred_volume = p.squeeze(0).detach().numpy()
myshow(sitk.GetImageFromArray(pred_volume))

interactive(children=(IntSlider(value=119, description='z', max=239), Output()), _dom_classes=('widget-interac…

In [64]:
def dice_score(o, t, eps=1e-8):
    """Return the Dice overlap of two masks and print a one-line summary.

    Args:
        o: predicted mask; must support element-wise ``*`` and ``.sum()``.
        t: ground-truth mask with the same requirements.
        eps: constant added to numerator and denominator, so the ratio stays
            defined (and evaluates to 1) when both masks are empty.

    Returns:
        ``(2 * sum(o * t) + eps) / (sum(o) + sum(t) + eps)``.
    """
    pred_count = o.sum()
    gt_count = t.sum()
    overlap = (o * t).sum()

    num = 2 * overlap + eps
    den = pred_count + gt_count + eps

    # The "All_voxels" prefix is fixed text; it is not derived from the inputs.
    summary = 'All_voxels:240*240*155 | numerator:{} | denominator:{} | pred_voxels:{} | GT_voxels:{}'.format(
        int(num), int(den), pred_count, int(gt_count)
    )
    print(summary)
    return num / den


def softmax_output_dice(output, target):
    """Compute three Dice scores between a predicted and a reference label map.

    The regions are, in order:
      1. whole  -- every non-zero voxel of ``output`` vs. every non-zero voxel
         of ``target``;
      2. core   -- ``output`` in {1, 3} vs. ``target`` in {1, 4};
      3. active -- ``output == 3`` vs. ``target == 4``.

    Args:
        output: predicted label map (class indices).
        target: reference label map.

    Returns:
        A list ``[whole, core, active]`` of values returned by ``dice_score``.

    NOTE(review): the reference side is matched against label 4 while the
    prediction side uses 3. If the targets passed in only contain 0..3, the
    core/active reference masks will miss that class -- confirm the label
    encoding of ``target`` before relying on these numbers.
    """
    # whole
    whole = dice_score(output > 0, target > 0)

    # core
    core_pred = (output == 1) | (output == 3)
    core_ref = (target == 1) | (target == 4)
    core = dice_score(core_pred, core_ref)

    # active
    active = dice_score(output == 3, target == 4)

    return [whole, core, active]

In [65]:
# Reload the un-flattened label of the selected case and score the padded
# prediction `p` against it.
label = valid_dataset[index]["labels"].unsqueeze(0)
scores = softmax_output_dice(p, label)

# `zip` stops at the shorter sequence: only three scores are returned, so the
# trailing 'loss' key is never printed.
keys = 'whole', 'core', 'enhancing', 'loss'
parts = ['{}: {:.4f}'.format(name, value) for name, value in zip(keys, scores)]

msg = ', '.join(parts)
print(msg)


All_voxels:240*240*155 | numerator:135432 | denominator:157911 | pred_voxels:74493 | GT_voxels:83418
All_voxels:240*240*155 | numerator:440 | denominator:8872 | pred_voxels:1237 | GT_voxels:7635
All_voxels:240*240*155 | numerator:0 | denominator:6047 | pred_voxels:0 | GT_voxels:6047
whole: 0.8576, core: 0.0496, enhancing: 0.0000


In [66]:
# List the distinct values present in the prediction volume `p`.
p.unique()

tensor([0., 1., 2.])