In [1]:
import h5py
import torch
from matplotlib import pyplot as plt
import os, sys, numpy as np
import torch
from sononet.utils.util import json_file_to_pyobj
from skimage.transform import resize
from sononet.models import get_model
import cv2

In [15]:
class HiddenPrints:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is replaced by a handle on ``os.devnull``; on exit
    the previous stream is restored and the handle is closed. Exceptions raised
    inside the ``with`` body are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Use a real (null) file object: `sys.stdout = None` only silences
        # print(), while direct sys.stdout.write()/flush() calls would raise
        # AttributeError.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original stream first, then release the devnull handle.
        sys.stdout = self._original_stdout
        self._devnull.close()
        return False  # never swallow exceptions from the with-body


def crop_image(image, crop_range):
    """Cut a rectangular window out of ``image``.

    ``crop_range[0]`` gives the (start, stop) bounds for the first axis and
    ``crop_range[1]`` the (start, stop) bounds for the second axis; every
    remaining axis is kept whole. Plain slicing is used, so a numpy input
    yields a view rather than a copy.
    """
    first_axis, second_axis = crop_range[0], crop_range[1]
    window = (
        slice(first_axis[0], first_axis[1]),
        slice(second_axis[0], second_axis[1]),
        Ellipsis,
    )
    return image[window]


def image_loader(image_name, image_size, device='cuda'):
    """Load an image file as a standardised single-channel float tensor.

    Args:
        image_name: path passed to ``cv2.imread``.
        image_size: target ``(rows, cols)`` for ``skimage.transform.resize``;
            each entry is cast to ``int``.
        device: device the result is moved to. Defaults to ``'cuda'``, which
            keeps the original behaviour.

    Returns:
        float32 tensor of shape ``(1, 1, rows, cols)``. It is shifted by its
        mean and divided by its standard deviation, both computed over the
        whole image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_name``.
    """
    image = cv2.imread(image_name)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("Could not read image '{}'".format(image_name))
    # cv2.imread returns channels in BGR order. BGR2GRAY applies the luminance
    # weights to the right channels; RGB2GRAY would swap the R and B weights.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    resized = resize(gray, (int(image_size[0]), int(image_size[1])), preserve_range=True)
    # Convert to float32 and prepend two singleton axes -> (1, 1, rows, cols).
    image_t = torch.from_numpy(resized).float()[None, None]
    image_t_norm = (image_t - image_t.mean()) / image_t.std()
    return image_t_norm.to(device)


def test_image(image):
        model.set_input(image)
        model.net.eval()
        with torch.no_grad():
            model.forward(split='test')
        scores = model.logits.data[0].cpu()
        scores = scores.numpy()
        pr_lbls = model.pred[1].item()
        return scores, pr_lbls

In [5]:
# Model configuration (JSON); its `model` section is passed to get_model.
json_filename = 'sononet/config_sononet_8.json'
json_opts = json_file_to_pyobj(json_filename)
# Saved network weights, loaded into `model.net` once the model is built.
checkpoint_file = 'sononet/checkpoints/300_net_S.pth'

In [6]:
# Build the model; HiddenPrints discards whatever get_model writes to stdout.
with HiddenPrints():
    model = get_model(json_opts.model)

# NOTE(review): presumably disables auxiliary deep-supervision outputs for
# inference — confirm in sononet.models.
if hasattr(model.net, 'deep_supervised'):
    model.net.deep_supervised = False

# Load checkpoint
if os.path.isfile(checkpoint_file):
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine.
    # load_state_dict then copies the values into the existing parameters.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    model.net.load_state_dict(checkpoint)
    print("=> Loaded checkpoint '{}'".format(checkpoint_file))
else:
    # NOTE(review): execution continues with whatever weights get_model
    # produced; later predictions are not from the trained checkpoint.
    print("=> No checkpoint found at '{}'!!!!!!".format(checkpoint_file))

=> Loaded checkpoint 'sononet/checkpoints/300_net_S.pth'


In [16]:
# Open the image archive explicitly read-only: h5py < 3.0 defaulted to append
# mode ('a'), which could create or modify the file.
# NOTE(review): the handle is never closed in this notebook.
foo = h5py.File('data/image_inp_224x288.hdf5', 'r')

In [17]:
# Pick one test sample. Index 0 on the second axis presumably selects the
# single channel — confirm against the dataset layout.
SAMPLE_INDEX = 901
sample_array = foo['images_test'][SAMPLE_INDEX, 0]
bar = torch.from_numpy(sample_array).float()

In [18]:
# Standardise by the sample's own mean/std, then prepend two singleton axes
# (presumably batch and channel).
image = ((bar - bar.mean()) / bar.std())[None, None]

In [19]:
scores, predicted_label = test_image(image)
scores, predicted_label

(array([7.24119809e-06, 5.76486018e-05, 9.37142904e-05, 2.02036276e-03,
        4.81211841e-02, 9.47071791e-01, 2.20068311e-03, 3.98239354e-05,
        6.35359829e-05, 1.51838285e-05, 2.17778594e-04, 9.88415104e-06,
        5.48396456e-05, 2.63970269e-05], dtype=float32), 5)