In [1]:
# Interactive (ipympl) figures for the slice viewers below; the comment sits on
# its own line because IPython would pass trailing text to the magic as args.
%matplotlib widget

In [2]:
from pathlib import Path

import torch

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from niftynet.io.image_reader import ImageReader
from niftynet.utilities.util_common import ParserNamespace
from niftynet.io.image_sets_partitioner import ImageSetsPartitioner
from niftynet.layer.histogram_normalisation import (
    HistogramNormalisationLayer,
)
from niftynet.layer.binary_masking import (
    BinaryMaskingLayer,
)

from highresnet import HighRes3DNet

INFO:tensorflow:TensorFlow version 1.10.0
CRITICAL:tensorflow:Optional Python module cv2 not found, please install cv2 and retry if the application fails.
CRITICAL:tensorflow:Optional Python module skimage.io not found, please install skimage.io and retry if the application fails.
INFO:tensorflow:Available Image Loaders:
['nibabel', 'pillow', 'simpleitk', 'dummy'].
[1mINFO:niftynet:[0m Optional Python module yaml not found, please install yaml and retry if the application fails.
[1mINFO:niftynet:[0m Optional Python module yaml version None not found, please install yaml-None and retry if the application fails.


In [3]:
# Location of the PyTorch state dict for HighRes3DNet, loaded with torch.load
# in the inference cell.
# NOTE(review): presumably converted from the NiftyNet checkpoint — confirm.
STATE_DICT_DIR = '/tmp'
state_dict_pt_path = STATE_DICT_DIR + '/state_dict_pt.pth'

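## Standardisation

Histogram-standardise the input volume with NiftyNet's `HistogramNormalisationLayer`, using the histogram model file distributed with the pretrained HighRes3DNet brain-parcellation model, and save the result for the steps below.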

In [4]:
def preprocess(input_path, model_path, output_path, cutoff):
    """Histogram-standardise one NIfTI volume with NiftyNet and save it.

    The file is read through NiftyNet's ``ImageReader`` (with
    ``axcodes='RAS'``), normalised by ``HistogramNormalisationLayer`` using
    the model file ``model_path`` and a ``'mean_plus'`` binary mask, and
    written to ``output_path``.

    Parameters
    ----------
    input_path : str or Path
        NIfTI file to standardise.
    model_path : str or Path
        Histogram-normalisation model file read by
        ``HistogramNormalisationLayer``.
    output_path : str or Path
        Destination of the standardised image.
    cutoff : tuple of float
        Lower/upper cutoff pair forwarded unchanged to
        ``HistogramNormalisationLayer``.

    Raises
    ------
    FileNotFoundError
        If ``input_path`` is not among the files NiftyNet finds in its
        directory.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)

    DATA_PARAM = {
        'Modality0': ParserNamespace(
            path_to_search=str(input_path.parent),
            # NiftyNet matches files in the folder by substring; restrict the
            # scan to this file's name instead of every '.nii.gz' present.
            filename_contains=(input_path.name,),
            interp_order=0,
            pixdim=None,
            axcodes='RAS',
            loader=None,
        )
    }
    TASK_PARAM = ParserNamespace(image=('Modality0',))

    data_partitioner = ImageSetsPartitioner()
    file_list = data_partitioner.initialise(DATA_PARAM).get_file_list()
    # Substring matching can still catch other files (e.g. 'x_<name>'), and
    # output_list[0] below must be the requested volume: keep the exact path
    # only. reset_index keeps row labels 0..n-1 for the reader.
    target = input_path.resolve()
    is_target = file_list['Modality0'].map(
        lambda found: Path(found).resolve() == target)
    file_list = file_list[is_target].reset_index(drop=True)
    if file_list.empty:
        raise FileNotFoundError(
            'NiftyNet did not find {} in {}'.format(input_path.name,
                                                    input_path.parent))

    reader = ImageReader(['image'])
    reader.initialise(DATA_PARAM, TASK_PARAM, file_list)

    hist_norm = HistogramNormalisationLayer(
        image_name='image',
        modalities=['Modality0'],
        model_filename=str(model_path),
        binary_masking_func=BinaryMaskingLayer(type_str='mean_plus'),
        cutoff=cutoff,
        name='hist_norm_layer',
    )

    image = reader.output_list[0]['image']
    # The layer returns (image dict, mask dict); only the image is needed.
    norm_image_dict, _ = hist_norm({'image': image.get_data()})
    normalised = norm_image_dict['image'].squeeze()
    # NOTE(review): the reader reoriented the data to RAS, but the affine used
    # here is the file's original one; this is only consistent if the input is
    # already RAS (as the 'reoriented' input suggests) — confirm.
    nii = nib.Nifti1Image(normalised, image.original_affine[0])
    nii.to_filename(str(output_path))

# Standardise the reoriented input with the histogram model of the pretrained
# HighRes3DNet brain-parcellation model. The model directory is resolved under
# the current user's home instead of a hard-coded /home/<user> path.
input_path = '/tmp/pt/reoriented/reoriented.nii.gz'
model_path = (
    Path.home()
    / 'niftynet' / 'models' / 'highres3dnet_brain_parcellation'
    / 'databrain_std_hist_models_otsu.txt'
)
if not model_path.is_file():
    raise FileNotFoundError(
        'Histogram model not found: {}'.format(model_path))
output_path = Path('/tmp/hist.nii.gz')
cutoff = (0.001, 0.999)  # from .ini
preprocess(input_path, model_path, output_path, cutoff)

[1mINFO:niftynet:[0m 

Number of subjects 1, input section names: ['subject_id', 'Modality0']
-- using all subjects (without data partitioning).

[1mINFO:niftynet:[0m Image reader: loading 1 subjects from sections ('Modality0',) as input [image]


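## Whitening

Shift and scale the standardised volume to zero mean and unit variance. Set `fg_only = True` to compute these statistics only over voxels brighter than the global mean (a rough foreground mask).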

In [11]:
# Whiten the standardised volume to zero mean and unit variance. With
# fg_only=True the statistics come from voxels brighter than the global mean
# (a crude foreground mask); the shift/scale is applied to the whole volume.
nii = nib.load(str(output_path))
# get_data() is deprecated in nibabel (removed in 5.0). Request float32
# explicitly: the arithmetic below then never hits an integer dtype, and the
# array stays single precision for the tensor built from it later.
# A new array is produced instead of mutating nibabel's cached data, so the
# cell is idempotent.
data_std = nii.get_fdata(dtype=np.float32)
fg_only = False
values = data_std[data_std > data_std.mean()] if fg_only else data_std
values_mean = values.mean()
values_std = values.std()
if values_std == 0:
    raise ValueError('Cannot whiten a constant volume (standard deviation is 0)')
data = (data_std - values_mean) / values_std

In [12]:
def plot_volume(array):
    """Show the three central orthogonal slices of a 3D volume side by side.

    Each slice is taken at the middle index of one axis, rotated by 90 degrees
    and mirrored left-right for display, and drawn in grayscale on its own
    subplot of a new 1x3 figure.

    Parameters
    ----------
    array : 3D array
        Volume to display; its shape must have exactly three entries.
    """
    def _for_display(section):
        return np.fliplr(np.rot90(section))

    n_i, n_j, n_k = array.shape
    fig, axes = plt.subplots(1, 3)
    central_slices = (
        array[n_i // 2, ...],
        array[:, n_j // 2, :],
        array[..., n_k // 2],
    )
    for axis, section in zip(axes, central_slices):
        axis.imshow(_for_display(section), cmap='gray')
    plt.tight_layout()

In [13]:
# Crop a cube of `size` voxels around a manually chosen centre (voxel indices).
# Bounds are clipped to the volume: a negative start index would otherwise be
# read by NumPy as counting from the end, silently giving an empty/wrong crop.
size = 128
center = np.array((80, 116, 128))
ini = np.clip(center - size // 2, 0, data.shape)
fin = np.clip(ini + size, 0, data.shape)
roi = data[ini[0]:fin[0], ini[1]:fin[1], ini[2]:fin[2]]
plot_volume(roi)
plt.show()

FigureCanvasNbAgg()

In [14]:
# Run HighRes3DNet on the ROI. map_location puts the weights on `device` at
# load time, so a state dict saved from a GPU also loads on the CPU fallback.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict_pt = torch.load(state_dict_pt_path, map_location=device)
# NOTE(review): presumably (in_channels=1, num_classes=160) — confirm against
# highresnet.HighRes3DNet.
model = HighRes3DNet(1, 160, add_dropout_layer=True)
model.load_state_dict(state_dict_pt)
model = model.to(device)
model.eval()

# float32 regardless of the ROI's NumPy dtype, so the input matches the
# model's float parameters.
tensor = torch.tensor(roi, dtype=torch.float32)
tensor = tensor.unsqueeze(0)  # channels dimension
batch = tensor.unsqueeze(0)  # batch dimension
batch = batch.to(device)

with torch.no_grad():
    output = model(batch)

# Only `output` is used by later cells; drop the model reference.
del model

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 16, 128, 128, 128])
134.217728 MB

torch.Size([1, 32, 128, 128, 128])
268.435456 MB

torch.Size([1, 32, 128, 128, 128])
268.435456 MB

torch.Size([1, 32, 128, 128, 128])
268.435456 MB

torch.Size([1, 32, 128, 128, 128])
268.435456 MB

torch.Size([1, 32, 128, 128, 128])
268.435456 MB

torch.Size([1, 32, 128, 128, 128])
268.435456 MB

torch.Size([1, 64, 128, 128, 128])
536.870912 MB

torch.Size([1, 64, 128, 128, 128])
536.870912 MB

torch.Size([1, 64, 128, 128, 128])
536.870912 MB

torch.Size([1, 64, 128, 128, 128])
536.870912 MB

torch.Size([1, 64, 128, 128, 128])
536.870912 MB

torch.Size([1, 64, 128, 128, 128])
536.870912 MB

torch.Size([1, 80, 128, 128, 128])
671.08864 MB

t

In [15]:
# Most likely class per voxel. Indexing [0] removes only the batch axis; a
# blanket .squeeze() would also drop any size-1 spatial axis (possible for a
# clipped ROI) and break pasting the labels back into the full volume.
labels = output[0].argmax(dim=0).cpu().numpy()
plot_volume(roi)
plot_volume(labels)

FigureCanvasNbAgg()

FigureCanvasNbAgg()

In [10]:
# Paste the ROI predictions into a full-size uint16 volume (0 outside the ROI)
# and save it with the affine of the whitened image loaded above.
roi_slices = tuple(slice(start, stop) for start, stop in zip(ini, fin))
label_map = np.zeros_like(data, dtype=np.uint16)
label_map[roi_slices] = labels
label_nii = nib.Nifti1Image(label_map, nii.affine)
label_nii.to_filename('/tmp/labels.nii.gz')