In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import imageio
import tensorflow as tf
from tqdm import tqdm

%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 6)

MODELS_DIR = "/home/mel/datasets/outmark/"

In [None]:
# Names of the exported models; each one is imported into the shared graph
# under a scope of the same name, so its tensors are addressed as "<name>/...".
models = ["abdomen", "lits"]

# One graph and one session shared by every model in the notebook.
graph = tf.Graph()
session = tf.Session(graph=graph)

with graph.as_default():
    for model in models:
        # NOTE(review): import_meta_graph returns a Saver that is discarded
        # here, and the variables are initialised below rather than restored.
        # Confirm the weights are embedded in the graph files or loaded elsewhere.
        tf.train.import_meta_graph(
            "".join((MODELS_DIR, model, ".graphdef")),
            import_scope=model,
        )

    # Built inside the graph context so the initializer op belongs to `graph`
    # and covers the variables of every imported model.
    session.run(tf.global_variables_initializer())

In [None]:
def normalize_image(image):
    """Return a float32 copy of `image` divided by the magnitude of its minimum.

    The input array is never modified. When the absolute value of the minimum
    is not strictly positive (e.g. the minimum is 0), the float32 copy is
    returned unscaled so that no division by zero takes place.
    """
    scaled = image.astype(np.float32)
    magnitude = np.abs(np.min(scaled))

    # Guard clause: nothing to scale by, hand back the plain float32 copy.
    if not magnitude > 0:
        return scaled

    # Divide in place on the copy; the caller's array is left untouched.
    np.divide(scaled, magnitude, out=scaled)
    return scaled

def extract_features(image, dropout_keep_prob=1.0):
    """Run every loaded model over a volume, slice by slice, and join the features.

    Args:
        image: array indexed as ``image[z, :, :]``. It is passed through
            ``normalize_image`` first, so the caller's array is not modified.
        dropout_keep_prob: value fed to each model's ``keep_prob`` placeholder.
            Defaults to 1.0 because ``is_training`` is fed ``False`` (inference):
            if a graph applies ``keep_prob`` unconditionally, a value below 1
            would leave dropout active and make the features random.
            Pass 0.5 to reproduce the previous behaviour.

    Returns:
        Array whose first axis indexes the slices of ``image`` and whose last
        axis holds the features of all models concatenated in ``models`` order.

    Raises:
        ValueError: from ``np.stack`` when ``image`` has no slices.
    """
    image = normalize_image(image)

    per_model_features = []
    for model in models:
        # Each model was imported under its own scope, hence the name prefix.
        inputs = graph.get_tensor_by_name(model + "/X:0")
        is_training = graph.get_tensor_by_name(model + "/is_training_1:0")
        keep_prob = graph.get_tensor_by_name(model + "/keep_prob:0")
        features = graph.get_tensor_by_name(model + "/features:0")

        per_slice = []
        for z in range(image.shape[0]):
            # Add two leading singleton axes: (H, W) -> (1, 1, H, W).
            # presumably (batch, depth, height, width) - TODO confirm against the model
            X = image[z, :, :][np.newaxis, np.newaxis]
            f = session.run(
                features,
                feed_dict={
                    inputs: X,
                    is_training: False,
                    keep_prob: dropout_keep_prob,
                },
            )

            # Drop the two leading singleton axes again.
            per_slice.append(f[0, 0, :, :])

        per_model_features.append(np.stack(per_slice, axis=0))

    return np.concatenate(per_model_features, axis=-1)

# Smoke test: a random 5-slice 32x32 volume must come back with the same
# leading three dimensions (the trailing feature axis is not checked).
assert extract_features(
    np.random.normal(size=(5, 32, 32))
).shape[:3] == (5, 32, 32)

In [None]:
# Root of the dataset. Set the DSBOWL_DIR environment variable to point the
# notebook at another location; the original path is kept as the default.
DSBOWL_DIR = os.environ.get("DSBOWL_DIR") or "/large/datasets/dsbowl2018/"

# train_images[k] is the image of the k-th sample; train_masks[k] is the list
# of all mask images found for that same sample.
train_images = []
train_masks = []

train_root = os.path.join(DSBOWL_DIR, "stage1_train")

# os.listdir returns entries in arbitrary order; sorting keeps indices such as
# train_images[0] stable between runs and machines.
for item in tqdm(sorted(os.listdir(train_root))):
    basedir = os.path.join(train_root, item)
    if not os.path.isdir(basedir):
        # Skip stray files (e.g. OS metadata) that are not sample directories.
        continue

    imagesdir = os.path.join(basedir, "images")
    masksdir = os.path.join(basedir, "masks")

    # The image file is named after its sample directory.
    image = imageio.imread(os.path.join(imagesdir, item + ".png"))
    masks = [
        imageio.imread(os.path.join(masksdir, mask))
        for mask in sorted(os.listdir(masksdir))
    ]

    train_images.append(image)
    train_masks.append(masks)

In [None]:
# Rescale the first training image (presumably 8-bit, so into [0, 1] - TODO confirm).
x = train_images[0]/255.0
# Bring the last axis to the front so that it plays the role of the slice axis.
# NOTE(review): swapaxes(0, 2) also sends the original first axis to the back,
# so each x[i] is the transpose of the corresponding plane - confirm intended.
x = np.swapaxes(x, 0, 2)

f = extract_features(x)

# Top row: input slices; bottom row: first feature channel for the same slice.
nrows = 2
ncols = 4
fig, axes = plt.subplots(nrows, ncols, figsize=(12, 12), squeeze=False)
for i in range(ncols):
    axes[0, i].imshow(x[i, :, :], cmap='gray')
    axes[0, i].set_title("input slice {}".format(i))

    axes[1, i].imshow(f[i, :, :, 0], cmap='gray')
    axes[1, i].set_title("feature 0, slice {}".format(i))

# plt.show() instead of fig.show(): the notebook uses the inline backend,
# where Figure.show() is not the supported way to display a figure.
plt.show()