In [1]:
from sheet_id.models.FCN import FCN
from sheet_id.utils.loss_functions import softmax_sparse_crossentropy_ignoring_background, softmax_sparse_crossentropy
from sheet_id.utils.metrics import sparse_accuracy_ignoring_background, sparse_accuracy
from sheet_id.utils.dataPreprocessing import splitTrainValidation
from sheet_id.utils.dataGenerator import DataGenerator
from sheet_id.utils.dwd_utils import generateGroundTruthMaps
from sheet_id.utils.base_utils import generateSheetMaskAnnotation
from sheet_id.utils.eval_utils import evaluate

from keras.layers.convolutional import Conv2D
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam

import glob
import math
import os
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
from copy import deepcopy

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Build the three-headed prediction model: FCN backbone on a 500x500 single-channel
# input, followed by three per-pixel 1x1-conv heads, then restore trained weights.
model = FCN(input_shape=(500,500,1), n_classes=124)
input_map = Input(shape=(500,500,1))
output_featuremaps = model(input_map)
# Per-pixel heads: 10 energy channels, 124 class channels, 2 bbox channels.
# NOTE(review): energy_map and class_map use relu and are trained with
# softmax_sparse_crossentropy; if that loss applies softmax to raw logits, relu
# clips all negative logits — confirm intended. Changing it now would not match
# how the checkpoint loaded below was trained.
energy_output = Conv2D(10,  (1,1), activation='relu', padding='same', name='energy_map')(output_featuremaps)
class_output  = Conv2D(124, (1,1), activation='relu', padding='same', name='class_map')(output_featuremaps)
bbox_output   = Conv2D(2,   (1,1), activation='relu', padding='same', name='bbox_map')(output_featuremaps)
# Output order [energy, class, bbox] is what dwd_model.predict(...) returns later.
dwd_model = Model(inputs=[input_map], outputs=[energy_output, class_output, bbox_output])
loss_fn = softmax_sparse_crossentropy
# Loss/optimizer settings only matter for training/evaluation; this notebook
# only calls predict() after loading weights.
optimizer = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, 
                 epsilon=1e-8, decay=0.0, amsgrad=False)
metrics = []
# Losses are keyed by the head layer names set above.
dwd_model.compile(loss={
                            "energy_map": loss_fn,
                            "class_map": loss_fn,
                            "bbox_map": "mse",
                        }, 
                  loss_weights={'energy_map': 1.0, 'class_map': 1.0, 'bbox_map': 0.25},
                  optimizer=optimizer, metrics=metrics)
# Restore trained weights. Presumably matched by layer order (h5 load without
# by_name), so the layer construction order above must match the checkpoint.
dwd_model.load_weights('../checkpoints/dwd-finetune-5000.h5')

In [103]:
# Load one page as a single-channel image (flag 0 == cv2.IMREAD_GRAYSCALE).
# NOTE(review): absolute, machine-specific path — consider moving it to a config cell.
image_path = "/home/mirlab/prepped_tif/bach_bwv8491_v1/bach_bwv8491_v1-1.tif"
img = cv2.imread(image_path, 0)
# cv2.imread does not raise on a missing/unreadable file; it silently returns None,
# which would otherwise only fail later with an opaque error on `img.shape`.
if img is None:
    raise FileNotFoundError("Could not read image: " + image_path)

In [109]:
# Sliding-window inference: tile the page `img` into overlapping window_size patches
# (stride step_size), run dwd_model on each patch, and stitch the per-pixel outputs
# back into page-sized maps (energy_map_final, class_map_final, bbox_map_final).
window_size = (500, 500)  # must match the model's Input shape (500, 500, 1)
step_size = (450, 450)    # stride; neighbouring windows overlap by window_size - step_size
# Pixels whose predicted energy bin (argmax over the energy channels) is below this
# index are written as 255 in the binarized energy map, all others as 0.
energy_threshold = 5

# A stride larger than the window would leave canvas regions that no patch writes.
assert step_size[0] <= window_size[0] and step_size[1] <= window_size[1], \
    "step_size must not exceed window_size"

(n_rows, n_cols) = img.shape
# Windows needed per axis. max(1, ...) guarantees at least one window: for sides
# <= window - step the ceil term is <= -1, which would otherwise give zero windows
# and leave the output maps unfilled.
n_steps_r = max(1, 1 + math.ceil((n_rows - window_size[0]) / step_size[0]))
n_steps_c = max(1, 1 + math.ceil((n_cols - window_size[1]) / step_size[1]))

# Padded canvas exactly covered by the windows; cropped back to the image size below.
canvas_shape = (window_size[0] + (n_steps_r - 1) * step_size[0],
                window_size[1] + (n_steps_c - 1) * step_size[1])
# np.zeros (not np.empty) so uninitialized memory can never reach the result.
energy_map_final = np.zeros(canvas_shape)
class_map_final = np.zeros(canvas_shape)
bbox_map_final = np.zeros(canvas_shape + (2,))

for i in range(n_steps_r):
    for j in range(n_steps_c):
        # Load patch; the part of a window that hangs past the image edge is padded with 255
        start_row = i * step_size[0]
        start_col = j * step_size[1]
        end_row = min(n_rows, start_row + window_size[0])
        end_col = min(n_cols, start_col + window_size[1])

        img_patch = np.ones(window_size) * 255
        img_patch[:(end_row-start_row), :(end_col-start_col)] = img[start_row:end_row, start_col:end_col]

        # Model prediction on a batch of one single-channel patch + conversion
        energy_map, class_map, bbox_map = dwd_model.predict(img_patch.reshape(1, window_size[0], window_size[1], 1))
        energy_map_binarized = 255 * (np.argmax(energy_map, axis=-1)[0] < energy_threshold)  # binarize image
        class_prediction_img = np.argmax(class_map, axis=-1)[0]
        bbox_prediction      = bbox_map[0, :, :, :]

        # Paste into the canvas; in overlapping regions later windows overwrite earlier ones.
        energy_map_final[start_row:start_row+window_size[0], start_col:start_col+window_size[1]] = energy_map_binarized
        class_map_final[start_row:start_row+window_size[0], start_col:start_col+window_size[1]] = class_prediction_img
        bbox_map_final[start_row:start_row+window_size[0], start_col:start_col+window_size[1], :] = bbox_prediction

# Drop the padding so the maps align pixel-for-pixel with `img`.
energy_map_final = energy_map_final[:n_rows, :n_cols]
class_map_final = class_map_final[:n_rows, :n_cols]
bbox_map_final = bbox_map_final[:n_rows, :n_cols, :]