In [1]:
import argparse
import cv2
import numpy as np
import os
import pathlib

import torch

import models
from utils import image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
INPUT_FILE = "dataset/train/train_frames/image/image2.png"
# INPUT_FILE = "test1/image0.jpg"
OUTPUT_MASK = "output_mask.png"
OUTPUT_FILE = "output_pred.png"
MODEL_FILE = "pretrained/model_checkpoint.pt"

In [4]:
def predict_image(model, image):
    with torch.no_grad():
        output = model(image.to(device))

    output = output.detach().cpu().numpy()[0]
    output = output.transpose((1, 2, 0))
    output = np.uint8(output)
    _, output = cv2.threshold(output, 127, 255, cv2.THRESH_BINARY_INV)

    return output


In [5]:
model = models.UNet(n_channels=1, n_classes=1)

checkpoint = torch.load(pathlib.Path(MODEL_FILE))
model.load_state_dict(checkpoint['model_state_dict'])
# model.load_state_dict(checkpoint)
model.to(device)
model.eval()

  checkpoint = torch.load(pathlib.Path(MODEL_FILE))


UNet(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, trac

In [6]:
img, h, w = image.load_image(INPUT_FILE)

print('Prediction...')
output_image = predict_image(model, img)

print('Resize mask to original size...')
mask_image = cv2.resize(output_image, (w, h))
cv2.imwrite(OUTPUT_MASK, mask_image)

Prediction...
Resize mask to original size...


True

In [7]:
raw_image = cv2.imread(INPUT_FILE)
warped = image.extract_idcard(raw_image, mask_image)
warped = cv2.cvtColor(warped, cv2.COLOR_BGR2RGB)
cv2.imwrite(OUTPUT_FILE, warped)

AssertionError: 