In [None]:
"""Predicting Module."""

import os
import re
from collections import OrderedDict
from typing import List

import albumentations as album
import click
import easyocr
import numpy as np
import pandas as pd
from albumentations import Compose
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.morphology import closing, square, convex_hull_image
from skimage.segmentation import clear_border
from skimage.transform import resize
from skimage.util import invert

from tablenet import TableNetModule


class Predict:
    """Extract tables from an image with a pre-trained TableNet model and OCR."""

    def __init__(self, checkpoint_path: str, transforms: "Compose", threshold: float = 0.5, per: float = 0.005,
                 languages: tuple = ("ru",)):
        """Load the model checkpoint and create the OCR reader.

        Args:
            checkpoint_path (str): model weights path.
            transforms (Compose): Compose object from albumentations used for pre-processing.
            threshold (float): model outputs strictly above this value are marked 1 in the masks.
            per (float): minimum fraction of the mask area a table/column region must
                exceed to be kept.
            languages (tuple): language codes passed to ``easyocr.Reader``; defaults to Russian.
        """
        self.transforms = transforms
        self.threshold = threshold
        self.per = per

        self.model = TableNetModule.load_from_checkpoint(checkpoint_path)
        self.model.eval()
        self.model.requires_grad_(False)
        self.reader = easyocr.Reader(list(languages))

    def predict(self, image: "Image.Image") -> List[pd.DataFrame]:
        """Predict the table values of an image.

        Args:
            image (Image.Image): PIL image to extract tables from.

        Returns (List[pd.DataFrame]): one DataFrame per detected table that has at least
            one detected column; each DataFrame column holds the OCR text of one table column.
        """
        # The pixel math in _column_to_dataframe broadcasts a (H, W, 1) mask against the
        # image, which only works for (H, W, C) data, so grayscale/RGBA/palette inputs
        # are normalised to RGB up front.
        if image.mode != "RGB":
            image = image.convert("RGB")

        processed_image = self.transforms(image=np.array(image))["image"]

        table_mask, column_mask = self.model(processed_image.unsqueeze(0))

        table_mask = self._apply_threshold(table_mask)
        column_mask = self._apply_threshold(column_mask)

        segmented_tables = self._process_tables(self._segment_image(table_mask))

        tables = []
        for table in segmented_tables:
            # Restrict the column mask to the current table before segmenting columns.
            segmented_columns = self._process_columns(self._segment_image(column_mask * table))
            if segmented_columns:
                cols = [
                    self._column_to_dataframe(column, image, self.reader)
                    for column in segmented_columns.values()
                ]
                tables.append(pd.concat(cols, ignore_index=True, axis=1))
        return tables

    def _apply_threshold(self, mask):
        """Drop the two leading axes of a model output and binarise it to a 0/1 int array."""
        # detach()/cpu() keep the numpy conversion valid even if the output is attached
        # to a graph or lives on another device.
        mask = mask.squeeze(0).squeeze(0).detach().cpu().numpy() > self.threshold
        return mask.astype(int)

    def _process_tables(self, segmented_tables):
        """Return the convex hull of every labelled region larger than the area threshold."""
        min_area = segmented_tables.size * self.per
        labels = np.unique(segmented_tables)
        tables = []
        # Label 0 is background; filter it explicitly rather than assuming it is present.
        for i in labels[labels != 0]:
            table = np.where(segmented_tables == i, 1, 0)
            if table.sum() > min_area:
                tables.append(convex_hull_image(table))

        return tables

    def _process_columns(self, segmented_columns):
        """Return large-enough column masks ordered by their centroid's column coordinate."""
        min_area = segmented_columns.size * self.per
        labels = np.unique(segmented_columns)
        cols = {}
        for j in labels[labels != 0]:
            column = (segmented_columns == j).astype(int)

            if column.sum() > min_area:
                # centroid is (row, col); index 1 is the horizontal position.
                # NOTE: two regions with an identical centroid column would overwrite each other.
                position = regionprops(column)[0].centroid[1]
                cols[position] = column
        return OrderedDict(sorted(cols.items(), key=lambda item: item[0]))

    @staticmethod
    def _segment_image(image):
        """Label connected regions of a mask after Otsu thresholding, closing and border clearing."""
        # A uniform mask has no foreground to separate, and some scikit-image releases
        # raise in threshold_otsu for single-valued input; return an empty labelling.
        if image.min() == image.max():
            return np.zeros(image.shape, dtype=int)

        thresh = threshold_otsu(image)
        bw = closing(image > thresh, square(2))
        cleared = clear_border(bw)
        label_image = label(cleared)
        return label_image

    @staticmethod
    def _column_to_dataframe(column, image, reader):
        """OCR the part of ``image`` covered by ``column`` and return the text as a one-column DataFrame.

        Args:
            column: 2-D column mask at model resolution.
            image: PIL image the mask refers to.
            reader: object exposing ``readtext`` (an ``easyocr.Reader`` in this class).
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        width, height = image.size
        # Scale the mask up to the original image size; the trailing axis lets it
        # broadcast across the colour channels.
        column = resize(np.expand_dims(column, axis=2), (height, width), preserve_range=True) > 0.01

        # Keep the pixels inside the column and paint everything else white.
        crop = np.where(column, np.asarray(image), np.uint8(255))
        ocr = reader.readtext(crop.astype(np.uint8))
        # Each result is indexed at [1] for the recognised text, so require that element to exist.
        return pd.DataFrame({"col": [value[1] for value in ocr if len(value) > 1]})

In [None]:
# Check jpg/csv pairing and build the list of complete pairs.
rootdir = "./datasets/405/"
# Require a real ".csv"/".jpg" extension (the dot is part of the match).
regex_find = re.compile(r'.*\.(csv|jpg)$')


def collect_image_csv_names(directory):
    """Walk ``directory`` and collect the base names of its jpg and csv files.

    jpg files whose base name has leading/trailing whitespace are renamed on disk
    to the stripped name, and the stripped name is what gets recorded, so the
    returned names always match the files that exist after the call.

    Args:
        directory: root directory to walk (subdirectories included).

    Returns:
        tuple: ``(jpg_names, csv_names)`` lists of file names without extension.
    """
    jpg_names = []
    csv_names = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            if not regex_find.match(file):
                continue
            stem, ext = os.path.splitext(file)
            if ext == '.jpg':
                clean_stem = stem.strip()
                if clean_stem != stem:
                    # Strip stray spaces from photo names; use the directory the file
                    # was actually found in, not the walk root.
                    os.rename(os.path.join(root, file), os.path.join(root, clean_stem + ext))
                jpg_names.append(clean_stem)
            else:
                csv_names.append(stem)
    return jpg_names, csv_names


jpg_list, csv_list = collect_image_csv_names(rootdir)

# Sets give O(1) membership tests; the lists keep the original reporting order.
jpg_set = set(jpg_list)
csv_set = set(csv_list)

for x in jpg_list:
    if x not in csv_set:
        print(f'Отсутствует csv для {x}.jpg')

for x in csv_list:
    if x not in jpg_set:
        print(f'Отсутствует jpg для {x}.csv')

# Sorted so the processing order is reproducible between runs.
pairs_list = sorted(jpg_set & csv_set)

In [None]:
model_weights = "./tablenet/ocr_model.ckpt"

transforms = album.Compose([
    album.Resize(896, 896, always_apply=True),
    album.Normalize(),
    ToTensorV2()
])

# Build the predictor once: constructing it loads the checkpoint and the OCR reader,
# which does not need to be repeated for every image.
pred = Predict(model_weights, transforms)

for x in pairs_list:
    image_path = './datasets/405/'+x+'.jpg'
    try:
        # The context manager closes the image file once the prediction is done.
        with Image.open(image_path) as image:
            print(pred.predict(image))
    except Exception as exc:
        # Best-effort batch run: report the failing item with its cause and continue.
        # Exception (not a bare except) lets KeyboardInterrupt stop the loop.
        print(x, 'не обработан', exc)