In [1]:
import os
import sys
import random
import torch
import torchvision
import pytesseract
from pdf2image import convert_from_path
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import transforms

import cv2
import numpy as np

from utils import (
    overlay_ann,
    overlay_mask,
    show,
    extract_elements
)

In [2]:
seed = 1234
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


CATEGORIES2LABELS = {
    0: "bg",
    1: "text",
    2: "title",
    3: "list",
    4: "table",
    5: "figure"
}

In [3]:
def get_instance_segmentation_model(num_classes):
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256

    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )
    return model

In [4]:
def get_page_elements(image_path):
    # model 
    num_classes = 6
    model = get_instance_segmentation_model(num_classes)
    model.cuda()

    if os.path.exists('model_196000.pth'):
        checkpoint_path = "model_196000.pth"
    else:
        checkpoint_path = "../../../Downloads/model_196000.pth"

    assert os.path.exists(checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()
    
    # NOTE: custom  image
    assert os.path.exists(image_path)
    image_name = (image_path.split('/')[-1]).split('.')[0]
    print(image_path)
    elements_path = f'{image_name}_elements'
    try:
        os.mkdir(elements_path)
    except:
        print ("Creation of the directory %s failed" % elements_path)
    else:
        print("Successfully created the directory %s " % elements_path)
        
        
    image = cv2.imread(image_path)
    rat = 1300 / image.shape[0]
    image = cv2.resize(image, None, fx=rat, fy=rat)

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()
    ])
    image = transform(image)

    with torch.no_grad():
        prediction = model([image.cuda()])

    image = torch.squeeze(image, 0).permute(1, 2, 0).mul(255).numpy().astype(np.uint8)
    
    ROI_number = 0 
    for pred in prediction:
        for idx, mask in enumerate(pred['masks']):
            if pred['scores'][idx].item() < 0.7:
                continue

            m = mask[0].mul(255).byte().cpu().numpy()
            box = list(map(int, pred["boxes"][idx].tolist()))
            label = CATEGORIES2LABELS[pred["labels"][idx].item()]

            score = pred["scores"][idx].item()

            # image = overlay_mask(image, m)
            extract_elements(image, box, label, ROI_number, elements_path)
            image = overlay_ann(image, m, box, label, score)
            ROI_number += 1
    image_save_path = (f'./{elements_path}/masked_{os.path.basename(image_path)}')
    cv2.imwrite(image_save_path, image)
    # show(image)

In [5]:
def main(argv):
    if len(argv) > 0 and os.path.exists(argv[0]):
        file_path = argv[0]
    else:
        file_path = 'CVPR2017.pdf'
    
    pages = convert_from_path(file_path, dpi=200)
    file_name = file_path.split('.')[0]
    try:
        os.mkdir(f'../tmp/images/{file_name}')
    except:
        print ("Creation of the directory %s failed" % f'../tmp/images/{file_name}')
    else:
        print("Successfully created the directory %s " % f'../tmp/images/{file_name}')
    
    for idx,page in enumerate(pages):
        page.save(f"../tmp/images/CVPR2017/{idx}.png", 'PNG')
        image_path = f'../tmp/images/CVPR2017/{idx}.png'
        get_page_elements(image_path)    

In [6]:
if __name__ == "__main__":
    import sys
    argv = sys.argv[1:]
    main(argv)

PDFPageCountError: Unable to get page count.
I/O Error: Couldn't open file 'CVPR2017.pdf': No error.
