In [1]:
import torch
import cv2
from segmentation_models_pytorch import Unet
import os
from card_segmentation.utils import image
import numpy as np
import matplotlib.pyplot as plt
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_PATH = "./card_segmentation/pretrained/model_final.pt"
DATA_PATH  = "./data/MIDV500"
DATABASE   = "./database"
PREDICTION = "./prediction"

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
model = Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)

checkpoint = torch.load(MODEL_PATH)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

  checkpoint = torch.load(MODEL_PATH)


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [5]:
def predict_image(model, image):
    with torch.no_grad():
        output = model(image.to(device))

    output = output.detach().cpu().numpy()[0]
    output = output.transpose((1, 2, 0))
    output = np.uint8(output)
    _, output = cv2.threshold(output, 127, 255, cv2.THRESH_BINARY_INV)

    return output

In [6]:
def crop_image(image_dir):
    img, w, h = image.load_image(image_dir)
    mask = predict_image(model, img)
    mask = cv2.resize(mask, (w, h))
    raw_image = cv2.imread(image_dir)
    warped = image.extract_idcard(raw_image, mask)
    return warped

In [7]:
failed_segment = []

In [8]:
for i in range(len(os.listdir(DATA_PATH))):
    label = str(i+1)
    label_dir = DATA_PATH + '/' + label
    dst_dir = PREDICTION + '/' + label
    if os.path.exists(dst_dir):
        shutil.rmtree(dst_dir)
    os.makedirs(dst_dir)

    for img in os.listdir(label_dir):
        img_dir = label_dir + '/' + img
        warped_img = crop_image(img_dir)
        if warped_img is not None:
            warped_img = cv2.resize(warped_img, (256, 256))
            warped_img = cv2.cvtColor(warped_img, cv2.COLOR_BGR2RGB)
            cv2.imwrite(dst_dir + '/' + img, warped_img)
        else:
            failed_segment.append(img_dir)