Prediction script for the Merge model: predict the D and R matrices and visualize the result.

In [None]:
import os

In [None]:
# Expose only GPU 0 to this process; this runs before torch is imported below.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [None]:
from modules.merge_modules import MergeModel
import torch.backends.cudnn as cudnn
import torch
import json
from dataset.dataset import ImageDataset
from PIL import Image
import cv2
import numpy as np

In [None]:
# Build the Merge model for inference.
# cuDNN flags: turn on the auto-tuner and request deterministic algorithms.
cudnn.benchmark = True
cudnn.deterministic = True
# Instantiate on the GPU, then wrap in DataParallel (itself moved to the GPU).
net = torch.nn.DataParallel(MergeModel(3).cuda()).cuda()

In [None]:
# load saved checkpoint
# `net` is already wrapped in DataParallel at this point, so the checkpoint keys
# presumably carry the `module.` prefix — TODO confirm against how it was saved.
# NOTE(review): torch.load may unpickle arbitrary objects; only load trusted files.
net.load_state_dict(torch.load('Merge_model.pth'))

In [None]:
# change the model to eval mode so that layers with train/eval-specific
# behaviour (e.g. dropout, batch normalization) run in their inference form
net.eval()

In [None]:
# Build the dataset for the chosen split: labels come from
# `<folder>_merge_dict.json`, images from the `<folder>_input` directory.
folder = 'validation'
split_root = 'D:/dataset/table/table_line/Split1/'
label_path = split_root + folder + '_merge_dict.json'
image_dir = split_root + folder + '_input'

with open(label_path, 'r') as label_file:
    labels = json.load(label_file)

dataset = ImageDataset(image_dir, labels, 8, scale=0.25, mode='merge')

In [None]:
# Pick one sample to run through the model.
# NOTE(review): `index` is reset to 0 every time this cell runs, so the
# increment below only matters to code that reads `index` afterwards.
index = 0
# Each dataset item unpacks into (image, label, arc); `label` is not used
# anywhere else in the visible code.
img, label, arc = dataset[index]
index += 1

In [None]:
# Predict the merge matrices for the selected sample and draw the resulting
# cell grid on top of the input image.

# Add a leading batch dimension: `img` is a single sample from the dataset.
input_img = img.unsqueeze(0)
# Wrap every entry of `arc` in a 1-element tensor before passing it to the net.
arc_c = [[torch.Tensor([y]) for y in x] for x in arc]

# Inference only: disable autograd so the forward pass does not build a graph.
with torch.no_grad():
    pred = net(input_img, arc_c)
    u, d, l, r = pred  # up, down, left, right
    # Calculate the D and R matrices.
    # D pairs the "up" map without its last row with the "down" map without its
    # first row; R pairs "right" without its last column with "left" without
    # its first column, so each entry sits between two adjacent grid cells.
    D = 0.5 * u[:, :-1, :] * d[:, 1:, :] + 0.25 * (u[:, :-1, :] + d[:, 1:, :])
    R = 0.5 * r[:, :, :-1] * l[:, :, 1:] + 0.25 * (r[:, :, :-1] + l[:, :, 1:])

# Keep the first (only) batch element and binarize at 0.5.
D = D[0].detach().cpu().numpy()
R = R[0].detach().cpu().numpy()
D[D > 0.5] = 1
D[D <= 0.5] = 0
R[R > 0.5] = 1
R[R <= 0.5] = 0

# Convert the separator positions in `arc` to pixel coordinates and add the
# image borders as the outermost grid lines.
# NOTE(review): assumes `arc` holds positions as fractions of the image
# height/width — confirm against the dataset.
rows, columns = arc
h, w = img[2].shape
rows = [round(h * x) for x in rows]
columns = [round(w * x) for x in columns]
rows = [0] + rows + [h]
columns = [0] + columns + [w]

# draw lines on the original image
draw_img = img[2].numpy() * 255.
draw_img = cv2.cvtColor(draw_img, cv2.COLOR_GRAY2RGB)
# R[i, j] == 0: draw the vertical line between cell (i, j) and cell (i, j + 1).
for i, j in np.argwhere(R == 0):
    pts1 = (columns[j + 1], rows[i])
    pts2 = (columns[j + 1], rows[i + 1])
    draw_img = cv2.line(draw_img, pts1, pts2, (255., 0, 0), 2)
# D[i, j] == 0: draw the horizontal line between cell (i, j) and cell (i + 1, j).
for i, j in np.argwhere(D == 0):
    pts1 = (columns[j], rows[i + 1])
    pts2 = (columns[j + 1], rows[i + 1])
    draw_img = cv2.line(draw_img, pts1, pts2, (255., 0, 0), 2)

In [None]:
# Show the untouched input: channel 2 of the sample scaled by 255 and
# converted to an 8-bit grayscale image.
gray = img[2].numpy() * 255.
Image.fromarray(gray).convert('L')

In [None]:
# Show the image with the predicted grid lines drawn on it.
merged_view = np.array(draw_img, dtype=np.uint8)  # cast to 8-bit before handing it to PIL
Image.fromarray(merged_view)