In [1]:
import h5py
import torch
from matplotlib import pyplot as plt
import os, sys, numpy as np
import torch
from sononet.utils.util import json_file_to_pyobj
from skimage.transform import resize
from sononet.models import get_model
import cv2

In [68]:
class HiddenPrints:
    """Context manager that silences everything written to ``sys.stdout``.

    Output is redirected to ``os.devnull`` rather than ``None``. ``os.devnull``
    is a real file object, so code inside the block that calls
    ``sys.stdout.write`` or ``sys.stdout.flush`` keeps working. The previous
    stream is restored on exit, even when the body raises.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        self._devnull.close()
        # Implicitly returns None (falsy): exceptions from the body propagate.


def crop_image(image, crop_range):
    """Crop the first two axes of ``image``; any further axes are kept whole.

    ``crop_range`` is ``((start0, stop0), (start1, stop1))`` holding ordinary
    slice bounds for axis 0 and axis 1 respectively.
    """
    rows, cols = crop_range[0], crop_range[1]
    row_slice = slice(rows[0], rows[1])
    col_slice = slice(cols[0], cols[1])
    return image[row_slice, col_slice, ...]


def image_loader(image_name, image_size, device='cuda'):
    """Load an image file as a standardised single-channel tensor.

    Args:
        image_name: path of an image file readable by ``cv2.imread``.
        image_size: target (rows, cols); each entry is cast with ``int``.
        device: where the result is placed. Defaults to ``'cuda'``, which
            reproduces the previous always-``.cuda()`` behaviour.

    Returns:
        float32 tensor of shape (1, 1, rows, cols), shifted to zero mean and
        divided by its (unbiased) standard deviation.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_name``.
    """
    image = cv2.imread(image_name)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError("Could not read image '{}'".format(image_name))
    # cv2.imread returns colour images in BGR channel order; COLOR_RGB2GRAY
    # would apply the red and blue luminance weights to the wrong channels.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # preserve_range keeps the original intensity scale (no rescale to [0, 1]).
    resized = resize(gray, (int(image_size[0]), int(image_size[1])), preserve_range=True)
    image_t = torch.from_numpy(resized).float()[None, None]  # -> (1, 1, rows, cols)
    image_t = image_t.sub(image_t.mean()).div(image_t.std())
    return image_t.to(device)


def test_image(image):
        model.set_input(image)
        model.net.eval()
        with torch.no_grad():
            model.forward(split='test')
        scores = model.logits.data.cpu()
        scores = scores.numpy()
        pr_lbls = model.pred
        return scores, pr_lbls

In [5]:
# Network configuration and trained weights. Both paths are relative to the
# notebook's working directory.
json_filename = 'sononet/config_sononet_8.json'
json_opts = json_file_to_pyobj(json_filename)
checkpoint_file = 'sononet/checkpoints/300_net_S.pth'

In [6]:
# Build the network from the JSON config; HiddenPrints suppresses whatever
# get_model writes to stdout.
with HiddenPrints():
    model = get_model(json_opts.model)

# Switch deep supervision off when the network exposes that flag.
if hasattr(model.net, 'deep_supervised'):
    model.net.deep_supervised = False

# Load checkpoint. Fail loudly: continuing with randomly initialised weights
# would make every prediction further down meaningless.
if not os.path.isfile(checkpoint_file):
    raise FileNotFoundError("=> No checkpoint found at '{}'".format(checkpoint_file))
# map_location='cpu' lets a checkpoint saved on a GPU load on any machine;
# load_state_dict then copies the tensors into the existing parameters.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_file, map_location='cpu')
model.net.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_file))

=> Loaded checkpoint 'sononet/checkpoints/300_net_S.pth'


In [29]:
# Open the dataset read-only. Before h5py 3.0 the default mode was 'a'
# (read/write, creating the file if missing), so a mistyped path would silently
# create an empty file, and the real file would be opened writable.
foo = h5py.File('data/image_inp_224x288.hdf5', 'r')

In [49]:
# Top-level members of the HDF5 file (iterating a group yields its member names).
list(foo)

['annotation_labels_test',
 'annotation_labels_train',
 'data_mean',
 'domain_labels_test',
 'domain_labels_train',
 'images_test',
 'images_train',
 'label_names',
 'label_numbers',
 'mean_image',
 'plane_labels_test',
 'plane_labels_train']

In [145]:
# Map each numeric plane label to its name. `ds[()]` reads the whole dataset;
# `Dataset.value` is deprecated since h5py 2.9 and removed in h5py 3.0.
label_names_by_number = dict(zip(foo['label_numbers'][()], foo['label_names'][()]))
label_names_by_number

{0: b'3VT',
 1: b'3VV',
 2: b'4CH',
 3: b'AA',
 4: b'ABDOMINAL',
 5: b'ARMS',
 6: b'BLADDER',
 7: b'BRAIN_BIPARIETAL',
 8: b'BRAIN_POSTFOSSA',
 9: b'CERVIX',
 10: b'CORD_INSERT',
 11: b'CX-PL',
 12: b'DIAPHRAGM',
 13: b'FACE_CORONAL',
 14: b'FEET',
 15: b'FEMUR',
 16: b'FHR',
 17: b'HANDS',
 18: b'IVS',
 19: b'KIDNEYS',
 20: b'LEGS',
 21: b'LIPS',
 22: b'LVOT',
 23: b'NECK',
 24: b'NT',
 25: b'ORBITS',
 26: b'OTHER',
 27: b'PLACENTA',
 28: b'PROFILE',
 29: b'RVOT',
 30: b'SEPTUM',
 31: b'SPINE',
 32: b'STOMACH'}

In [148]:
# Training-split label distribution as (label numbers, counts per label).
# `[()]` reads the full dataset; `.value` was removed in h5py 3.0.
np.unique(foo['plane_labels_train'][()], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
       dtype=uint8),
 array([  169,   280,  1025,    15,   927,    60,   107,  2068,  1348,
          142,   279,   124,   294,   141,     9,  1066,    20,    59,
           12,   401,    32,  1162,   332,    42,    40,    33, 23903,
          180,   177,   229,    90,   918,     8], dtype=int64))

In [None]:
# Label numbers treated as standard planes; names follow the
# label_numbers -> label_names mapping printed earlier in this notebook.
# NOTE(review): this cell shows "In [None]", i.e. it was never executed in the
# saved session -- re-run the notebook top to bottom before relying on it.
standard_planes = [
    1,   # 3VV
    2,   # 4CH
    4,   # ABDOMINAL
    15,  # FEMUR
    19,  # KIDNEYS
    21,  # LIPS
    22,  # LVOT
    28,  # PROFILE
    29,  # RVOT
]

In [143]:
# All numeric label ids stored in the file (read with `[()]`; `.value` was removed in h5py 3.0).
foo['label_numbers'][()]

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
      dtype=int64)

In [62]:
# Read the whole test image stack in a single HDF5 call, then wrap the NumPy
# array without an extra copy. Handing the h5py Dataset object itself to
# torch.tensor may make torch treat it as a generic Python sequence and read
# it piece by piece.
images_test = torch.from_numpy(foo['images_test'][()])

In [114]:
# Standardise each test image independently: subtract its own mean and divide
# by its own (unbiased) std -- the same per-image scheme image_loader applies.
n_test = len(images_test)
flat_test = images_test.reshape(n_test, -1).float()
# Stats shaped (N, 1, 1, ...) with one singleton per non-batch axis, so they
# broadcast per image whatever the image rank. A fixed (N, 1, 1, 1) would turn
# a 3-D (N, H, W) stack into an (N, N, H, W) tensor.
stat_shape = (n_test,) + (1,) * (images_test.dim() - 1)
means = flat_test.mean(dim=1).view(stat_shape)
stds = flat_test.std(dim=1).view(stat_shape)
images_test_transformed = images_test.sub(means).div(stds)

In [115]:
# Fixed-order loader over the standardised test images, 50 images per batch.
testloader = torch.utils.data.DataLoader(
    dataset=torch.utils.data.TensorDataset(images_test_transformed),
    batch_size=50,
    shuffle=False,
)

In [123]:
# Smoke test: classify only the first test batch and print (scores, predictions).
first_item = next(iter(testloader), None)
if first_item is not None:
    b, = first_item
    print(test_image(b))

(array([[3.12809706e-01, 1.47321627e-01, 5.56317717e-03, 4.19376604e-02,
        2.52200814e-04, 4.22958110e-04, 4.03672410e-03, 2.93651540e-02,
        5.86043997e-03, 1.64447486e-01, 1.44632428e-03, 2.81185597e-01,
        2.65429053e-03, 2.69663567e-03],
       [1.54532194e-01, 3.35095555e-01, 9.57720913e-04, 3.39347608e-02,
        1.68230000e-03, 9.29359521e-04, 1.21004134e-03, 7.25757622e-04,
        4.91103297e-03, 2.52350628e-01, 5.23692404e-04, 2.12490037e-01,
        4.10273846e-04, 2.46615586e-04],
       [6.46811366e-01, 1.65909715e-02, 4.47998937e-05, 1.95045723e-04,
        1.34857637e-05, 2.47830121e-06, 7.45807984e-06, 1.53581332e-05,
        7.37006621e-06, 4.42273170e-02, 5.38941777e-06, 2.92049706e-01,
        1.97639329e-05, 9.47838998e-06],
       [6.00976408e-01, 4.24075462e-02, 1.32748170e-03, 3.50554776e-03,
        6.28063353e-05, 6.77888456e-05, 1.20988050e-04, 3.07375944e-04,
        1.02016609e-04, 3.73296663e-02, 2.28757475e-04, 3.12900364e-01,
        6.35