In [None]:
%matplotlib inline 

In [None]:
import os
import random

from matplotlib import pyplot as plt
import cv2
from tensorpack.tfutils import SmartInit
from tensorpack.predict import OfflinePredictor, PredictConfig

from config import config, finalize_configs
from dataset.fintabnet import register_fintabnet
from dataset import DatasetRegistry
from modeling.generalized_rcnn import ResNetC4Model, ResNetFPNModel
from eval import predict_image
from viz import draw_final_outputs

In [None]:
# Register the FinTabNet splits found under config.DATA.BASEDIR with the project's
# dataset registry; the next cell looks up 'fintabnet_val' via DatasetRegistry.get.
# Then finalize the global config; the False argument is presumably the
# is_training flag (i.e. inference mode) — confirm against config.finalize_configs.
# NOTE(review): keep this order — finalize_configs likely reads dataset metadata
# (e.g. class names) from the registry, so registration must happen first.
register_fintabnet(config.DATA.BASEDIR)
finalize_configs(False)

In [None]:
# Fixed seed so the sample of validation pages visualized below is reproducible.
SEED = 42

# Load the FinTabNet validation split as a list of inference records (roidbs).
ds = DatasetRegistry.get('fintabnet_val')
roidbs = ds.inference_roidbs()
# Fail early with a clear message instead of rendering an empty figure later.
assert roidbs, "fintabnet_val returned no images; check config.DATA.BASEDIR"

# Shuffle in place with the fixed seed to pick a reproducible random subset.
random.seed(SEED)
random.shuffle(roidbs)
print("#images:", len(roidbs))

In [None]:
# Checkpoint passed to SmartInit to restore the trained weights.
model_file_path = 'train_log/fpn_v2/model-1440000'

# Choose the backbone variant selected by the config.
if config.MODE_FPN:
    MODEL = ResNetFPNModel()
else:
    MODEL = ResNetC4Model()

# Switch the result score threshold to the visualization threshold
# before the predictor graph is built.
config.TEST.RESULT_SCORE_THRESH = config.TEST.RESULT_SCORE_THRESH_VIS

# Build an offline predictor restored from the checkpoint, wired to the
# model's inference input/output tensor names.
input_tensor_names, output_tensor_names = MODEL.get_inference_tensor_names()
predcfg = PredictConfig(
    model=MODEL,
    session_init=SmartInit(model_file_path),
    input_names=input_tensor_names,
    output_names=output_tensor_names,
)
predictor = OfflinePredictor(predcfg)

In [None]:
# Grid layout and output location for the prediction gallery.
NROWS, NCOLS = 6, 2
OUTPUT_PATH = '.github/fpn_predictions.png'

# squeeze=False keeps `axes` a 2-D array for any grid shape, so `.flat` always works.
fig, axes = plt.subplots(NROWS, NCOLS, figsize=(16, 64), squeeze=False)

# Hide grid cells that will not receive a page (fewer pages than panels).
n_panels = min(len(roidbs), axes.size)
for ax in axes.flat[n_panels:]:
    ax.set_visible(False)

for r, ax in zip(roidbs, axes.flat):
    page_image = cv2.imread(r["file_name"])
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    if page_image is None:
        raise FileNotFoundError("Could not read image: {}".format(r["file_name"]))
    # predict_image and draw_final_outputs receive the BGR array exactly as
    # cv2.imread returned it; only the displayed copy is converted below.
    predictions = predict_image(page_image, predictor)
    debug_image = draw_final_outputs(page_image, predictions)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # matplotlib interprets 3-channel arrays as RGB; convert from OpenCV's BGR
    # so page and box colors are not swapped.
    ax.imshow(cv2.cvtColor(debug_image, cv2.COLOR_BGR2RGB))

# savefig does not create missing directories, so ensure the target exists.
output_dir = os.path.dirname(OUTPUT_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
plt.savefig(OUTPUT_PATH, bbox_inches='tight')
plt.show()